Skip to content

Commit

Permalink
Add fully discrete, unordered benchmark problem (#3211)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3211

This PR:
* Moves `IdentityTestFunction` from test stubs to new file `benchmark/test_functions/synthetic.py`.
* Adds `get_bandit_problem`, which uses a discrete search space on values 0, 1, ...9, so it can produce ground-truth metric values 0, 1, ... 9. It uses the inference trace (based on model-recommended best point) for scoring.

Reviewed By: Balandat

Differential Revision: D67535995

fbshipit-source-id: ee4a62f6db144f8b83b29dd6d56bf29ab422327c
  • Loading branch information
esantorella authored and facebook-github-bot committed Dec 24, 2024
1 parent ef1af42 commit c561e7a
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 22 deletions.
36 changes: 36 additions & 0 deletions ax/benchmark/benchmark_test_functions/synthetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field

import torch
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction


@dataclass(kw_only=True)
class IdentityTestFunction(BenchmarkTestFunction):
"""
Test function that returns the value of parameter "x0", ignoring any others.
"""

outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
n_steps: int = 1

# pyre-fixme[14]: Inconsistent override
def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
"""
Return params["x0"] for each outcome for each time step.
Args:
params: A dictionary with key "x0".
"""
value = params["x0"]
return torch.full(
(len(self.outcome_names), self.n_steps), value, dtype=torch.float64
)
4 changes: 4 additions & 0 deletions ax/benchmark/problems/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_pytorch_cnn_torchvision_benchmark_problem,
)
from ax.benchmark.problems.runtime_funcs import int_from_params
from ax.benchmark.problems.synthetic.bandit import get_bandit_problem
from ax.benchmark.problems.synthetic.discretized.mixed_integer import (
get_discrete_ackley,
get_discrete_hartmann,
Expand Down Expand Up @@ -55,6 +56,9 @@ class BenchmarkProblemRegistryEntry:
"name": "ackley4_async_noisy",
},
),
"Bandit": BenchmarkProblemRegistryEntry(
factory_fn=get_bandit_problem, factory_kwargs={}
),
"branin": BenchmarkProblemRegistryEntry(
factory_fn=create_problem_from_botorch,
factory_kwargs={
Expand Down
78 changes: 78 additions & 0 deletions ax/benchmark/problems/synthetic/bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from warnings import warn

import numpy as np

from ax.benchmark.benchmark_problem import BenchmarkProblem, get_soo_opt_config
from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
from ax.core.parameter import ChoiceParameter, ParameterType
from ax.core.search_space import SearchSpace


def get_baseline(num_choices: int, n_sims: int = 100000000) -> float:
"""
Compute the baseline value.
The baseline for this problem takes into account noise, because it uses the
inference trace, and the bandit structure, which allows for running all arms
in one noisy batch:
Run a BatchTrial with every arm, with equal size. Choose the arm with the
best observed value and take its true value. Take the expectation of the
outcome of this process.
"""
noise_per_arm = num_choices**0.5
sim_observed_effects = (
np.random.normal(0, noise_per_arm, (n_sims, num_choices))
+ np.arange(num_choices)[None, :]
)
identified_best_arm = sim_observed_effects.argmin(axis=1)
# because of the use of IdentityTestFunction
baseline = identified_best_arm.mean()
return baseline


def get_bandit_problem(num_choices: int = 30, num_trials: int = 3) -> BenchmarkProblem:
parameter = ChoiceParameter(
name="x0",
parameter_type=ParameterType.INT,
values=list(range(num_choices)),
is_ordered=False,
sort_values=False,
)
search_space = SearchSpace(parameters=[parameter])
test_function = IdentityTestFunction()
optimization_config = get_soo_opt_config(
outcome_names=test_function.outcome_names, observe_noise_sd=True
)
baselines = {
10: 1.40736478,
30: 2.4716703,
100: 4.403284,
}
if num_choices not in baselines:
warn(
f"Baseline value is not available for num_choices={num_choices}. Use "
"`get_baseline` to compute the baseline and add it to `baselines`."
)
baseline_value = baselines[30]
else:
baseline_value = baselines[num_choices]
return BenchmarkProblem(
name="Bandit",
num_trials=num_trials,
search_space=search_space,
optimization_config=optimization_config,
optimal_value=0,
baseline_value=baseline_value,
test_function=test_function,
report_inference_value_as_trace=True,
noise_std=1.0,
status_quo_params={"x0": num_choices // 2},
)
36 changes: 36 additions & 0 deletions ax/benchmark/tests/problems/synthetic/test_bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from ax.benchmark.problems.synthetic.bandit import get_bandit_problem, get_baseline
from ax.utils.common.testutils import TestCase


class TestProblems(TestCase):
def test_get_baseline(self) -> None:
num_choices = 5
baseline = get_baseline(num_choices=num_choices, n_sims=100)
self.assertGreater(baseline, 0)
# Worst = num_choices - 1; random guessing = (num_choices - 1) / 2
self.assertLess(baseline, (num_choices - 1) / 2)

def test_get_bandit_problem(self) -> None:
problem = get_bandit_problem()
self.assertEqual(problem.name, "Bandit")
self.assertEqual(problem.num_trials, 3)
self.assertTrue(problem.report_inference_value_as_trace)

problem = get_bandit_problem(num_choices=26, num_trials=4)
self.assertEqual(problem.num_trials, 4)
self.assertEqual(problem.status_quo_params, {"x0": 26 // 2})

def test_baseline_exception(self) -> None:
with self.assertWarnsRegex(
Warning, expected_regex="Baseline value is not available for num_choices=17"
):
problem = get_bandit_problem(num_choices=17)

self.assertEqual(problem.baseline_value, get_bandit_problem().baseline_value)
1 change: 1 addition & 0 deletions ax/benchmark/tests/problems/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_load_problems(self) -> None:

def test_name(self) -> None:
expected_names = [
("Bandit", "Bandit"),
("branin", "Branin"),
("hartmann3", "Hartmann_3d"),
("hartmann6", "Hartmann_6d"),
Expand Down
2 changes: 1 addition & 1 deletion ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ax.benchmark.benchmark_result import BenchmarkResult
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
from ax.benchmark.methods.modular_botorch import (
get_sobol_botorch_modular_acquisition,
get_sobol_mbm_generation_strategy,
Expand Down Expand Up @@ -61,7 +62,6 @@
get_multi_objective_benchmark_problem,
get_single_objective_benchmark_problem,
get_soo_surrogate,
IdentityTestFunction,
TestDataset,
)

Expand Down
4 changes: 3 additions & 1 deletion ax/benchmark/tests/test_benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from ax.benchmark.benchmark_runner import _add_noise, BenchmarkRunner
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction

from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
from ax.benchmark.problems.synthetic.hss.jenatton import (
get_jenatton_benchmark_problem,
Jenatton,
Expand All @@ -35,8 +37,8 @@
DummyTestFunction,
get_jenatton_trials,
get_soo_surrogate_test_function,
IdentityTestFunction,
)

from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
from botorch.utils.transforms import normalize
from pandas import DataFrame
Expand Down
2 changes: 0 additions & 2 deletions ax/modelbridge/tests/test_discrete_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,6 @@ def test_cross_validate(self, mock_init: Mock) -> None:
def test_get_parameter_values(self) -> None:
parameter_values = _get_parameter_values(self.search_space, ["x", "y", "z"])
self.assertEqual(parameter_values, [[0.0, 1.0], ["foo", "bar"], [True]])
# pyre-fixme[6]: For 1st param expected `List[Parameter]` but got
# `List[Union[ChoiceParameter, FixedParameter]]`.
search_space = SearchSpace(self.parameters)
search_space._parameters["x"] = RangeParameter(
"x", ParameterType.FLOAT, 0.1, 0.4
Expand Down
19 changes: 1 addition & 18 deletions ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

# pyre-strict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Any, Iterator

Expand All @@ -28,6 +27,7 @@
from ax.benchmark.benchmark_step_runtime_function import TBenchmarkStepRuntimeFunction
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
from ax.benchmark.benchmark_test_functions.synthetic import IdentityTestFunction
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_search_space
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial
Expand Down Expand Up @@ -301,23 +301,6 @@ def get_next_candidate(
return {self.param_name: next(self.iterator)}


@dataclass(kw_only=True)
class IdentityTestFunction(BenchmarkTestFunction):
outcome_names: Sequence[str] = field(default_factory=lambda: ["objective"])
n_steps: int = 1

# pyre-fixme[14]: Inconsistent override
def evaluate_true(self, params: Mapping[str, float]) -> torch.Tensor:
"""
Args:
params: A dictionary with key "x0".
"""
value = params["x0"]
return torch.full(
(len(self.outcome_names), self.n_steps), value, dtype=torch.float64
)


def get_discrete_search_space(n_values: int = 20) -> SearchSpace:
return SearchSpace(
parameters=[
Expand Down
17 changes: 17 additions & 0 deletions sphinx/source/benchmark.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ Benchmark Problems Registry
:undoc-members:
:show-inheritance:

Benchmark Problems: Bandit
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.benchmark.problems.synthetic.bandit
:members:
:undoc-members:
:show-inheritance:

Benchmark Problems High Dimensional Embedding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -146,6 +154,7 @@ Benchmark Problems PyTorchCNN TorchVision
:undoc-members:
:show-inheritance:


Benchmark Problems Runtime Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -169,3 +178,11 @@ Benchmark Test Functions: Surrogate
:members:
:undoc-members:
:show-inheritance:

Benchmark Test Functions: Synthetic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. automodule:: ax.benchmark.benchmark_test_functions.synthetic
:members:
:undoc-members:
:show-inheritance:

0 comments on commit c561e7a

Please sign in to comment.