Skip to content

Commit

Permalink
Do not transform empty parameter dict in SearchSpaceToChoice (#2789)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2789

Do not transform empty parameter dict in SearchSpaceToChoice. Before this fix, this may inappropriately transforming an empty dict into a arm signature, which leads to unexpected/broken behaviors of models such as EB.

Reviewed By: sdaulton

Differential Revision: D63471431

fbshipit-source-id: a68a7d8973cc89b15da6975114322339f40bae9d
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Sep 26, 2024
1 parent 75b4bf8 commit a144287
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
14 changes: 9 additions & 5 deletions ax/modelbridge/transforms/search_space_to_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,19 @@ def transform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
obsf.parameters = {
self.parameter_name: Arm(parameters=obsf.parameters).signature
}
# if obsf.parameters is not an empty dict
if len(obsf.parameters) != 0:
obsf.parameters = {
self.parameter_name: Arm(parameters=obsf.parameters).signature
}
return observation_features

def untransform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
signature = obsf.parameters[self.parameter_name]
obsf.parameters = self.signature_to_parameterization[signature]
# Do not untransform empty dict as it wasn't transformed in the first place
if len(obsf.parameters) != 0:
signature = obsf.parameters[self.parameter_name]
obsf.parameters = self.signature_to_parameterization[signature]
return observation_features
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def test_TransformObservationFeatures(self) -> None:
obs_ft2 = self.t.untransform_observation_features(obs_ft2)
self.assertEqual(obs_ft2, self.observation_features)

# Testing transform empty parameters dict
# Both transform and untransform should leave the param dict intact
empty_obs_param = ObservationFeatures(parameters={}, trial_index=0)
tsfm_empty_obs_param = self.t.transform_observation_features([empty_obs_param])[
0
]
self.assertEqual(tsfm_empty_obs_param, empty_obs_param)
untsfm_empty_obs_param = self.t.untransform_observation_features(
[tsfm_empty_obs_param]
)[0]
self.assertEqual(untsfm_empty_obs_param, empty_obs_param)

def test_w_robust_search_space(self) -> None:
rss = get_robust_search_space()
# Raises an error in __init__.
Expand Down

0 comments on commit a144287

Please sign in to comment.