Skip to content

Commit

Permalink
ReductionCriterion to inherit from FuncEnum` (which will help it pl…
Browse files Browse the repository at this point in the history
…ay well with serialization and storage) (facebook#2915)

Summary:

As titled, applying the new `FuncEnum` util

Reviewed By: mgarrard

Differential Revision: D63914987
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Oct 24, 2024
1 parent 322ea96 commit f6e35d1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
46 changes: 31 additions & 15 deletions ax/modelbridge/best_model_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from functools import partial
from enum import unique
from typing import Any, Union

import numpy as np
from ax.exceptions.core import UserInputError
from ax.modelbridge.model_spec import ModelSpec
from ax.utils.common.base import Base
from ax.utils.common.func_enum import FuncEnum
from ax.utils.common.typeutils import not_none

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Expand All @@ -36,27 +35,24 @@ def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec:
"""


class ReductionCriterion(Enum):
@unique
class ReductionCriterion(FuncEnum):
"""An enum for callables that are used for aggregating diagnostics over metrics
and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``.
NOTE: The methods defined by this enum should all share identical signatures:
``Callable[[ARRAYLIKE], np.ndarray]``, and reside in this file.
NOTE: This is used to ensure serializability of the callables.
"""

# NOTE: Callables need to be wrapped in `partial` to be registered as members.
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max)
MEAN = "mean_reduction_criterion"
MIN = "min_reduction_criterion"
MAX = "max_reduction_criterion"

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
return self.value(array_like)
return super().__call__(array_like=array_like)


class SingleDiagnosticBestModelSelector(BestModelSelector):
Expand Down Expand Up @@ -132,3 +128,23 @@ def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec:
best_diagnostic = self.criterion(aggregated_diagnostic_values).item()
best_index = aggregated_diagnostic_values.index(best_diagnostic)
return model_specs[best_index]


# ------------------------- Reduction criteria ------------------------- #


# Wrap the numpy functions, to be able to access them directly from this
# module in `ReductionCriterion(FuncEnum)` and to have typechecking
def mean_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray:
"""Compute the mean of an array-like object."""
return np.mean(array_like)


def min_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray:
"""Compute the min of an array-like object."""
return np.min(array_like)


def max_reduction_criterion(array_like: ARRAYLIKE) -> np.ndarray:
"""Compute the max of an array-like object."""
return np.max(array_like)
29 changes: 29 additions & 0 deletions ax/modelbridge/tests/test_best_model_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

# pyre-strict

import inspect
from unittest.mock import Mock, patch

import numpy as np

from ax.exceptions.core import UserInputError
from ax.modelbridge.best_model_selector import (
ReductionCriterion,
Expand Down Expand Up @@ -36,6 +39,32 @@ def setUp(self) -> None:
ms._last_cv_kwargs = {}
self.model_specs.append(ms)

def test_member_typing(self) -> None:
for reduction_criterion in ReductionCriterion:
signature = inspect.signature(reduction_criterion._get_function_for_value())
self.assertEqual(signature.return_annotation, "np.ndarray")

# pyre-fixme [56]: Pyre was not able to infer the type of argument
# `numpy` to decorator factory `unittest.mock.patch`
@patch(f"{ReductionCriterion.__module__}.np", wraps=np)
def test_ReductionCriterion(self, mock_np: Mock) -> None:
untested_reduction_criteria = set(ReductionCriterion)
# Check MEAN (should just fall through to `np.mean`)
array = np.array([1, 2, 3]) # and then use this var all the way down
self.assertEqual(ReductionCriterion.MEAN(array), np.mean(array))
mock_np.mean.assert_called_once()
untested_reduction_criteria.remove(ReductionCriterion.MEAN)
# Check MIN (should just fall through to `np.min`)
self.assertEqual(ReductionCriterion.MIN(np.array([1, 2, 3])), 1.0)
mock_np.min.assert_called_once()
untested_reduction_criteria.remove(ReductionCriterion.MIN)
# Check MAX (should just fall through to `np.max`)
self.assertEqual(ReductionCriterion.MAX(np.array([1, 2, 3])), 3.0)
mock_np.max.assert_called_once()
untested_reduction_criteria.remove(ReductionCriterion.MAX)
# There should be no untested reduction criteria left
self.assertEqual(len(untested_reduction_criteria), 0)

def test_user_input_error(self) -> None:
with self.assertRaisesRegex(UserInputError, "ReductionCriterion"):
SingleDiagnosticBestModelSelector(
Expand Down
3 changes: 3 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class TestGenerationStrategyWithoutModelBridgeMocks(TestCase):
test class that makes use of mocking rather sparingly.
"""

def _setUp(self) -> None:
super().setUp()

@fast_botorch_optimize
@patch(
"ax.modelbridge.generation_node._extract_model_state_after_gen",
Expand Down

0 comments on commit f6e35d1

Please sign in to comment.