-
Notifications
You must be signed in to change notification settings - Fork 409
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
acquisition function wrapper (#1532)
Summary: Pull Request resolved: #1532 Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs Differential Revision: D41629186 fbshipit-source-id: f27ddc07ba0dd1dc5eb5e91f8f1d00ebd55df49f
- Loading branch information
1 parent
63dd0cd
commit 98fc27f
Showing
8 changed files
with
144 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
r""" | ||
A wrapper classes around AcquisitionFunctions to modify inputs and outputs. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
from botorch.acquisition.acquisition import AcquisitionFunction | ||
from torch import Tensor | ||
from torch.nn import Module | ||
|
||
|
||
class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): | ||
r"""Abstract acquisition wrapper.""" | ||
|
||
def __init__(self, acq_function: AcquisitionFunction) -> None: | ||
Module.__init__(self) | ||
self.acq_func = acq_function | ||
|
||
@property | ||
def X_pending(self) -> Optional[Tensor]: | ||
r"""Return the `X_pending` of the base acquisition function.""" | ||
try: | ||
return self.acq_func.X_pending | ||
except (ValueError, AttributeError): | ||
raise ValueError( | ||
f"Base acquisition function {type(self.acq_func).__name__} " | ||
"does not have an `X_pending` attribute." | ||
) | ||
|
||
def set_X_pending(self, X_pending: Optional[Tensor]) -> None: | ||
r"""Sets the `X_pending` of the base acquisition function.""" | ||
self.acq_func.set_X_pending(X_pending) | ||
|
||
@abstractmethod | ||
def forward(self, X: Tensor) -> Tensor: | ||
r"""Evaluate the wrapped acquisition function on the candidate set X. | ||
Args: | ||
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim | ||
design points each. | ||
Returns: | ||
A `(b)`-dim Tensor of acquisition function values at the given | ||
design points `X`. | ||
""" | ||
pass # pragma: no cover |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from botorch.acquisition.analytic import ExpectedImprovement | ||
from botorch.acquisition.monte_carlo import qExpectedImprovement | ||
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper | ||
from botorch.exceptions.errors import UnsupportedError | ||
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior | ||
|
||
|
||
class DummyWrapper(AbstractAcquisitionFunctionWrapper): | ||
def forward(self, X): | ||
return self.acq_func(X) | ||
|
||
|
||
class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): | ||
def test_abstract_acquisition_function_wrapper(self): | ||
for dtype in (torch.float, torch.double): | ||
mm = MockModel( | ||
MockPosterior( | ||
mean=torch.rand(1, 1, dtype=dtype, device=self.device), | ||
variance=torch.ones(1, 1, dtype=dtype, device=self.device), | ||
) | ||
) | ||
acq_func = ExpectedImprovement(model=mm, best_f=-1.0) | ||
wrapped_af = DummyWrapper(acq_function=acq_func) | ||
self.assertIs(wrapped_af.acq_func, acq_func) | ||
# test forward | ||
X = torch.rand(1, 1, dtype=dtype, device=self.device) | ||
with torch.no_grad(): | ||
wrapped_val = wrapped_af(X) | ||
af_val = acq_func(X) | ||
self.assertEqual(wrapped_val.item(), af_val.item()) | ||
|
||
# test X_pending | ||
with self.assertRaises(ValueError): | ||
self.assertIsNone(wrapped_af.X_pending) | ||
with self.assertRaises(UnsupportedError): | ||
wrapped_af.set_X_pending(X) | ||
acq_func = qExpectedImprovement(model=mm, best_f=-1.0) | ||
wrapped_af = DummyWrapper(acq_function=acq_func) | ||
self.assertIsNone(wrapped_af.X_pending) | ||
wrapped_af.set_X_pending(X) | ||
self.assertTrue(torch.equal(X, wrapped_af.X_pending)) | ||
self.assertTrue(torch.equal(X, acq_func.X_pending)) | ||
wrapped_af.set_X_pending(None) | ||
self.assertIsNone(wrapped_af.X_pending) | ||
self.assertIsNone(acq_func.X_pending) |