-
Notifications
You must be signed in to change notification settings - Fork 313
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add fully discrete, unordered benchmark problem (#3211)
Summary: Pull Request resolved: #3211 This PR: * Moves `IdentityTestFunction` from test stubs to new file `benchmark/test_functions/synthetic.py`. * Adds `get_bandit_problem`, which uses a discrete search space on values 0, 1, ...9, so it can produce ground-truth metric values 0, 1, ... 9. It uses the inference trace (based on model-recommended best point) for scoring. Reviewed By: Balandat Differential Revision: D67535995 fbshipit-source-id: ee4a62f6db144f8b83b29dd6d56bf29ab422327c
- Loading branch information
1 parent
ef1af42
commit c561e7a
Showing
10 changed files
with
177 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from collections.abc import Mapping, Sequence | ||
from dataclasses import dataclass, field | ||
|
||
import torch | ||
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class IdentityTestFunction(BenchmarkTestFunction): | ||
""" | ||
Test function that returns the value of parameter "x0", ignoring any others. | ||
""" | ||
|
||
outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"]) | ||
n_steps: int = 1 | ||
|
||
# pyre-fixme[14]: Inconsistent override | ||
def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor: | ||
""" | ||
Return params["x0"] for each outcome for each time step. | ||
Args: | ||
params: A dictionary with key "x0". | ||
""" | ||
value = params["x0"] | ||
return torch.full( | ||
(len(self.outcome_names), self.n_steps), value, dtype=torch.float64 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from warnings import warn | ||
|
||
import numpy as np | ||
|
||
from ax.benchmark.benchmark_problem import BenchmarkProblem, get_soo_opt_config | ||
from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction | ||
from ax.core.parameter import ChoiceParameter, ParameterType | ||
from ax.core.search_space import SearchSpace | ||
|
||
|
||
def get_baseline(num_choices: int, n_sims: int = 100000000) -> float: | ||
""" | ||
Compute the baseline value. | ||
The baseline for this problem takes into account noise, because it uses the | ||
inference trace, and the bandit structure, which allows for running all arms | ||
in one noisy batch: | ||
Run a BatchTrial with every arm, with equal size. Choose the arm with the | ||
best observed value and take its true value. Take the expectation of the | ||
outcome of this process. | ||
""" | ||
noise_per_arm = num_choices**0.5 | ||
sim_observed_effects = ( | ||
np.random.normal(0, noise_per_arm, (n_sims, num_choices)) | ||
+ np.arange(num_choices)[None, :] | ||
) | ||
identified_best_arm = sim_observed_effects.argmin(axis=1) | ||
# because of the use of IdentityTestFunction | ||
baseline = identified_best_arm.mean() | ||
return baseline | ||
|
||
|
||
def get_bandit_problem(num_choices: int = 30, num_trials: int = 3) -> BenchmarkProblem: | ||
parameter = ChoiceParameter( | ||
name="x0", | ||
parameter_type=ParameterType.INT, | ||
values=list(range(num_choices)), | ||
is_ordered=False, | ||
sort_values=False, | ||
) | ||
search_space = SearchSpace(parameters=[parameter]) | ||
test_function = IdentityTestFunction() | ||
optimization_config = get_soo_opt_config( | ||
outcome_names=test_function.outcome_names, observe_noise_sd=True | ||
) | ||
baselines = { | ||
10: 1.40736478, | ||
30: 2.4716703, | ||
100: 4.403284, | ||
} | ||
if num_choices not in baselines: | ||
warn( | ||
f"Baseline value is not available for num_choices={num_choices}. Use " | ||
"`get_baseline` to compute the baseline and add it to `baselines`." | ||
) | ||
baseline_value = baselines[30] | ||
else: | ||
baseline_value = baselines[num_choices] | ||
return BenchmarkProblem( | ||
name="Bandit", | ||
num_trials=num_trials, | ||
search_space=search_space, | ||
optimization_config=optimization_config, | ||
optimal_value=0, | ||
baseline_value=baseline_value, | ||
test_function=test_function, | ||
report_inference_value_as_trace=True, | ||
noise_std=1.0, | ||
status_quo_params={"x0": num_choices // 2}, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from ax.benchmark.problems.synthetic.bandit import get_bandit_problem, get_baseline | ||
from ax.utils.common.testutils import TestCase | ||
|
||
|
||
class TestProblems(TestCase): | ||
def test_get_baseline(self) -> None: | ||
num_choices = 5 | ||
baseline = get_baseline(num_choices=num_choices, n_sims=100) | ||
self.assertGreater(baseline, 0) | ||
# Worst = num_choices - 1; random guessing = (num_choices - 1) / 2 | ||
self.assertLess(baseline, (num_choices - 1) / 2) | ||
|
||
def test_get_bandit_problem(self) -> None: | ||
problem = get_bandit_problem() | ||
self.assertEqual(problem.name, "Bandit") | ||
self.assertEqual(problem.num_trials, 3) | ||
self.assertTrue(problem.report_inference_value_as_trace) | ||
|
||
problem = get_bandit_problem(num_choices=26, num_trials=4) | ||
self.assertEqual(problem.num_trials, 4) | ||
self.assertEqual(problem.status_quo_params, {"x0": 26 // 2}) | ||
|
||
def test_baseline_exception(self) -> None: | ||
with self.assertWarnsRegex( | ||
Warning, expected_regex="Baseline value is not available for num_choices=17" | ||
): | ||
problem = get_bandit_problem(num_choices=17) | ||
|
||
self.assertEqual(problem.baseline_value, get_bandit_problem().baseline_value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters