Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add storage for encoding/decoding should_skip #2905

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
29 changes: 19 additions & 10 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
MaxGenerationParallelism,
MaxTrials,
MinTrials,
Expand Down Expand Up @@ -96,6 +97,8 @@ class GenerationNode(SerializationMixin, SortableBase):
set during transition from one ``GenerationNode`` to the next. Can be
overwritten if multiple transitions occur between nodes, and will always
store the most recent previous ``GenerationNode`` name.
should_skip: Whether to skip this node during generation time. Defaults to
False, and can currently only be set to True via ``NodeInputConstructors``.

Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
Expand All @@ -118,6 +121,7 @@ class GenerationNode(SerializationMixin, SortableBase):
]
_previous_node_name: str | None = None
_trial_type: str | None = None
_should_skip: bool = False

# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
Expand All @@ -141,6 +145,7 @@ def __init__(
) = None,
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand Down Expand Up @@ -169,6 +174,7 @@ def __init__(
)
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -237,7 +243,14 @@ def is_completed(self) -> bool:
"""Returns True if this GenerationNode is complete and should transition to
the next node.
"""
return self.should_transition_to_next_node(raise_data_required_error=False)[0]
# TODO: @mgarrard make this logic more robust and general
# We won't mark a node completed if it has an AutoTransitionAfterGen criterion
# as this is typically used in cyclic generation strategies
return self.should_transition_to_next_node(raise_data_required_error=False)[
0
] and not any(
isinstance(tc, AutoTransitionAfterGen) for tc in self.transition_criteria
)

@property
def previous_node(self) -> GenerationNode | None:
Expand Down Expand Up @@ -575,16 +588,10 @@ def should_transition_to_next_node(
# transition to the next node defined by that edge.
for next_node, all_tc in self.transition_edges.items():
transition_blocking = [tc for tc in all_tc if tc.block_transition_if_unmet]
gs_lgr = self.generation_strategy.last_generator_run
transition_blocking_met = all(
tc.is_met(
experiment=self.experiment,
trials_from_node=self.trials_from_node,
curr_node_name=self.node_name,
# TODO @mgarrard: should we instead pass a backpointer to gs/node
node_that_generated_last_gr=(
gs_lgr._generation_node_name if gs_lgr is not None else None
),
curr_node=self,
)
for tc in transition_blocking
)
Expand All @@ -600,7 +607,8 @@ def should_transition_to_next_node(
for tc in all_tc:
if (
tc.is_met(
self.experiment, trials_from_node=self.trials_from_node
self.experiment,
curr_node=self,
)
and raise_data_required_error
):
Expand Down Expand Up @@ -647,7 +655,8 @@ def generator_run_limit(self, raise_generation_errors: bool = False) -> int:
# TODO[mgarrard]: Raise a group of all the errors, from each gen-
# blocking transition criterion.
if criterion.is_met(
self.experiment, trials_from_node=self.trials_from_node
self.experiment,
curr_node=self,
):
criterion.block_continued_generation_error(
node_name=self.node_name,
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def repeat_arm_n(
if total_n < 6:
# if the next trial is small, we don't want to waste allocation on repeat arms
# users can still manually add repeat arms if they want before allocation
# and we need to designate this node as skipped for proper transition
next_node._should_skip = True
return 0
elif total_n <= 10:
return 1
Expand Down
45 changes: 28 additions & 17 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ def gen_with_multiple_nodes(
node_to_gen_from = self.nodes_dict[node_to_gen_from_name]
if should_transition:
node_to_gen_from._previous_node_name = node_to_gen_from_name
# reset should_skip as conditions may have changed; do not reset
# until now so node properties can be as up to date as possible
node_to_gen_from._should_skip = False
arms_from_node = self._determine_arms_from_node(
node_to_gen_from=node_to_gen_from,
arms_per_node=arms_per_node,
Expand All @@ -467,25 +470,32 @@ def gen_with_multiple_nodes(
gen_kwargs=gen_kwargs,
passed_fixed_features=fixed_features,
)
grs.extend(
self._gen_multiple(
# TODO: @mgarrard clean this up after gens merge. This is currently needed
# because the actual transition occurs in gs.gen(), but if a node is
# skipped, we need to transition here to actually initiate that transition
if node_to_gen_from._should_skip:
self._maybe_transition_to_next_node()
continue
if arms_from_node != 0:
grs.extend(
self._gen_multiple(
experiment=experiment,
num_generator_runs=1,
data=data,
n=arms_from_node,
pending_observations=pending_observations,
fixed_features=fixed_features_from_node,
)
)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
experiment=experiment,
num_generator_runs=1,
data=data,
n=arms_from_node,
pending_observations=pending_observations,
fixed_features=fixed_features_from_node,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
)
)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs

Expand Down Expand Up @@ -1044,7 +1054,8 @@ def _fit_current_model(self, data: Data | None) -> None:
self._model = self._curr._fitted_model

def _maybe_transition_to_next_node(
self, raise_data_required_error: bool = True
self,
raise_data_required_error: bool = True,
) -> bool:
"""Moves this generation strategy to next node if the current node is completed,
and it is not the last node in this generation strategy. This method is safe to
Expand Down
17 changes: 10 additions & 7 deletions ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,6 @@ def test_consume_all_n_constructor(self) -> None:

def test_repeat_arm_n_constructor(self) -> None:
"""Test that the repeat_arm_n_constructor returns a small percentage of n."""
small_n = NodeInputConstructors.REPEAT_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 5},
experiment=self.experiment,
)
medium_n = NodeInputConstructors.REPEAT_N(
previous_node=None,
next_node=self.sobol_generation_node,
Expand All @@ -79,10 +73,19 @@ def test_repeat_arm_n_constructor(self) -> None:
gs_gen_call_kwargs={"n": 11},
experiment=self.experiment,
)
self.assertEqual(small_n, 0)
self.assertEqual(medium_n, 1)
self.assertEqual(large_n, 2)

def test_repeat_arm_n_constructor_return_0(self) -> None:
    """Test that REPEAT_N returns 0 for a small request (n=5) and marks the
    node passed as ``next_node`` as skipped.
    """
    target_node = self.sobol_generation_node
    num_repeat_arms = NodeInputConstructors.REPEAT_N(
        previous_node=None,
        next_node=target_node,
        gs_gen_call_kwargs={"n": 5},
        experiment=self.experiment,
    )
    # No repeat arms are suggested for such a small request ...
    self.assertEqual(num_repeat_arms, 0)
    # ... and the input constructor flags the node so it can be skipped.
    self.assertTrue(target_node._should_skip)

def test_remaining_n_constructor_expect_1(self) -> None:
"""Test that the remaining_n_constructor returns the remaining n."""
# should return 1 because 4 arms already exist and 5 are requested
Expand Down
66 changes: 63 additions & 3 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,61 @@ def test_gs_setup_with_nodes(self) -> None:
logger.output,
)

def test_gs_with_suggested_n_is_zero(self) -> None:
    """Test that a node whose input constructor suggests zero arms is skipped
    during generation, and that it generates again once enough arms are
    requested.
    """
    exp = get_branin_experiment()
    # ``sobol_2`` sizes itself via REPEAT_N and always hands off to ``sobol_3``
    # within the same trial.
    node_2 = GenerationNode(
        node_name="sobol_2",
        model_specs=[self.sobol_model_spec],
        transition_criteria=[
            AutoTransitionAfterGen(
                transition_to="sobol_3", continue_trial_generation=True
            )
        ],
        input_constructors={
            InputConstructorPurpose.N: NodeInputConstructors.REPEAT_N
        },
    )
    # ``sobol_3`` takes the remaining arms and cycles back to ``sobol_2``.
    node_3 = GenerationNode(
        node_name="sobol_3",
        model_specs=[self.sobol_model_spec],
        transition_criteria=[
            AutoTransitionAfterGen(
                transition_to="sobol_2",
                block_transition_if_unmet=True,
                continue_trial_generation=False,
            ),
        ],
        input_constructors={
            InputConstructorPurpose.N: NodeInputConstructors.REMAINING_N
        },
    )
    gs = GenerationStrategy(nodes=[node_2, node_3])

    # First check that we can generate multiple times with a skipped node in
    # a cyclic gs dag
    for _ in range(3):
        # if you request < 6 arms, repeat arm input constructor will return 0 arms
        grs = gs.gen_with_multiple_nodes(experiment=exp, n=5)
        self.assertEqual(len(grs), 1)  # only generated from one node
        sole_gr = grs[0]
        self.assertEqual(sole_gr._generation_node_name, "sobol_3")
        self.assertEqual(len(sole_gr.arms), 5)  # all 5 arms from sobol 3
        self.assertTrue(node_2._should_skip)

    # Now validate that we can get grs from sobol_2 if we request enough n
    grs = gs.gen_with_multiple_nodes(experiment=exp, n=8)
    expected_per_node = [("sobol_2", 1), ("sobol_3", 7)]
    self.assertEqual(len(grs), 2)
    for gr, (expected_name, expected_num_arms) in zip(grs, expected_per_node):
        self.assertEqual(gr._generation_node_name, expected_name)
        self.assertEqual(len(gr.arms), expected_num_arms)
    self.assertFalse(node_2._should_skip)

def test_gen_with_multiple_nodes_pending_points(self) -> None:
exp = get_experiment_with_multi_objective()
gs = GenerationStrategy(
Expand Down Expand Up @@ -1380,7 +1435,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None:
# check first call is 6 (from the previous trial having 6 arms)
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 6)

def test_gs_initializes_all_previous_node_to_none(self) -> None:
def test_gs_initializes_default_props_correctly(self) -> None:
"""Test that all previous nodes are initialized to None"""
node_1 = GenerationNode(
node_name="node_1",
Expand All @@ -1401,13 +1456,18 @@ def test_gs_initializes_all_previous_node_to_none(self) -> None:
node_3,
],
)
with self.subTest("after initialization all should be none"):
with self.subTest("after initialization all previous nodes should be none"):
for node in gs._nodes:
self.assertIsNone(node._previous_node_name)
self.assertIsNone(node.previous_node)
with self.subTest("check previous node nodes after being set"):
with self.subTest("check previous node after it is set"):
gs._nodes[1]._previous_node_name = "node_1"
self.assertEqual(gs._nodes[1].previous_node, node_1)
with self.subTest(
"after initialization all nodes should have should_skip set to False"
):
for node in gs._nodes:
self.assertFalse(node._should_skip)

def test_gs_with_generation_nodes(self) -> None:
"Simple test of a SOBOL + MBM GenerationStrategy composed of GenerationNodes"
Expand Down
Loading