Include progression information as metadata when transforming Data to Observations

Differential Revision: D65255312
ltiao authored and facebook-github-bot committed Oct 31, 2024
1 parent 2c2241e commit ceb2293
Showing 1 changed file with 19 additions and 8 deletions.
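For context, a minimal sketch of how the new keyword argument could be used once this change is in place. The helper name, the idea of a "step" progression column, and the Experiment/MapData construction are assumptions for illustration, not part of this commit:

from ax.core.experiment import Experiment
from ax.core.map_data import MapData
from ax.core.observation import Observation, observations_from_data


def observations_with_progression(
    experiment: Experiment, map_data: MapData
) -> list[Observation]:
    """Build Observations from MapData without promoting map keys to parameters.

    With map_keys_as_parameters=False (the default), progression columns such as
    a hypothetical "step" map key are intended to travel with the Observations as
    metadata rather than as searchable parameters.
    """
    return observations_from_data(
        experiment=experiment,
        data=map_data,
        map_keys_as_parameters=False,
    )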
27 changes: 19 additions & 8 deletions ax/core/observation.py
@@ -439,6 +439,7 @@ def observations_from_data(
     data: Data,
     statuses_to_include: set[TrialStatus] | None = None,
     statuses_to_include_map_metric: set[TrialStatus] | None = None,
+    map_keys_as_parameters: bool = False,
 ) -> list[Observation]:
     """Convert Data to observations.
@@ -455,46 +456,55 @@ def observations_from_data(
             with statuses in this set. Defaults to all statuses except abandoned.
         statuses_to_include_map_metric: data from MapMetrics will only be included for
             trials with statuses in this set. Defaults to completed status only.
+        map_keys_as_parameters: Whether map_keys should be returned as part of
+            the parameters of the Observation objects.
     Returns:
         List of Observation objects.
     """
+    is_map_data = isinstance(data, MapData)
+    df = data.df if not is_map_data else data.map_df
+
     if statuses_to_include is None:
         statuses_to_include = NON_ABANDONED_STATUSES
     if statuses_to_include_map_metric is None:
         statuses_to_include_map_metric = {TrialStatus.COMPLETED}
-    feature_cols = get_feature_cols(data)
-    observations = []
+
+    feature_cols = get_feature_cols(data, is_map_data=is_map_data)
+
     arm_name_only = len(feature_cols) == 1  # there will always be an arm name
     # One DataFrame where all rows have all features.
-    isnull = data.df[feature_cols].isnull()
+    isnull = df[feature_cols].isnull()
     isnull_any = isnull.any(axis=1)
     incomplete_df_cols = isnull[isnull_any].any()
 
     # Get the incomplete_df columns that are complete, and usable as groupby keys.
+    obs_cols = OBS_COLS if not is_map_data else OBS_COLS.union(data.map_keys)
     complete_feature_cols = list(
-        OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols])
+        obs_cols.intersection(incomplete_df_cols.index[~incomplete_df_cols])
     )
 
     if set(feature_cols) == set(complete_feature_cols):
-        complete_df = data.df
+        complete_df = df
         incomplete_df = None
     else:
         # The groupby and filter is expensive, so do it only if we have to.
-        grouped = data.df.groupby(by=complete_feature_cols)
+        grouped = df.groupby(by=complete_feature_cols)
         complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any())
         incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any())
 
     # Get Observations from complete_df
+    observations = []
     observations.extend(
         _observations_from_dataframe(
             experiment=experiment,
             df=complete_df,
             cols=feature_cols,
             arm_name_only=arm_name_only,
+            map_keys=[] if not is_map_data else data.map_keys,
             statuses_to_include=statuses_to_include,
             statuses_to_include_map_metric=statuses_to_include_map_metric,
-            map_keys=[],
+            map_keys_as_parameters=map_keys_as_parameters,
         )
     )
     if incomplete_df is not None:
@@ -505,9 +515,10 @@ def observations_from_data(
             df=incomplete_df,
             cols=complete_feature_cols,
             arm_name_only=arm_name_only,
+            map_keys=[] if not is_map_data else data.map_keys,
             statuses_to_include=statuses_to_include,
             statuses_to_include_map_metric=statuses_to_include_map_metric,
-            map_keys=[],
+            map_keys_as_parameters=map_keys_as_parameters,
         )
     )
     return observations
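Below the diff, a standalone pandas sketch of the complete/incomplete split used above; the toy frame and column names (including the "step" stand-in for a map key) are assumptions for illustration, not code from Ax:

import numpy as np
import pandas as pd

# Toy frame standing in for data.map_df; "step" plays the role of a map key.
df = pd.DataFrame(
    {
        "arm_name": ["0_0", "0_0", "1_0"],
        "trial_index": [0, 0, 1],
        "step": [1.0, 2.0, np.nan],
    }
)
feature_cols = ["arm_name", "trial_index", "step"]

# Flag the columns that are null in at least one row with missing features.
isnull = df[feature_cols].isnull()
isnull_any = isnull.any(axis=1)
incomplete_df_cols = isnull[isnull_any].any()

# Columns that are populated everywhere remain usable as groupby keys.
complete_feature_cols = list(
    set(feature_cols).intersection(incomplete_df_cols.index[~incomplete_df_cols])
)

# Split rows into groups with all features present vs. groups missing some feature.
grouped = df.groupby(by=complete_feature_cols)
complete_df = grouped.filter(lambda g: ~g[feature_cols].isnull().any().any())
incomplete_df = grouped.filter(lambda g: g[feature_cols].isnull().any().any())

print(complete_df)    # both rows for arm 0_0 (every feature present)
print(incomplete_df)  # the row for arm 1_0 (missing "step")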
