ThompsonSampling acquisition function (#2443)
Summary:
Pull Request resolved: #2443

Adds Thompson sampling (approximated with random Fourier features (RFF) and pathwise sampling) as an acquisition function, so that it fits into general BO loops (and MBM, although that is secondary at the moment).

Amendment: Removed the fully Bayesian variant, since it did not make sense in its current format.

Reviewed By: saitcakmak

Differential Revision: D59961584

fbshipit-source-id: fccb53c0aa5ac990bed1de57b57c2d7719cd6581
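For orientation, a minimal sketch (not part of the diff) of how the new acquisition function plugs into a standard BoTorch loop. The toy objective, bounds, and optimize_acqf settings are illustrative assumptions, not prescribed by this commit.

import torch
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy objective, illustrative only

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Maximizing one fixed posterior sample path yields a Thompson sample.
acqf = PathwiseThompsonSampling(model=model)
candidate, _ = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64),
    q=1,
    num_restarts=4,
    raw_samples=64,
)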
1 parent 96a71e7 · commit 4497a5c
Showing 3 changed files with 228 additions and 0 deletions.
botorch/acquisition/thompson_sampling.py (new file)
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from botorch.acquisition.analytic import AcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.model import Model
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor


BATCH_SIZE_CHANGE_ERROR = """The batch size of PathwiseThompsonSampling should \
not change during a forward pass - was {}, now {}. Please re-initialize the \
acquisition if you want to change the batch size."""


class PathwiseThompsonSampling(AcquisitionFunction):
    r"""Single-outcome Thompson Sampling packaged as an (analytic)
    acquisition function. Querying the acquisition function gives the summed
    values of one or more draws from a pathwise drawn posterior sample, and thus
    its maximization yields one (or multiple) Thompson sample(s).

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> TS = PathwiseThompsonSampling(model)
    """

    def __init__(
        self,
        model: Model,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        r"""Single-outcome TS.

        Args:
            model: A fitted GP model.
            posterior_transform: A PosteriorTransform. If using a multi-output model,
                a PosteriorTransform that transforms the multi-output posterior into a
                single-output posterior is required.
        """
        if model._is_fully_bayesian:
            raise NotImplementedError(
                "PathwiseThompsonSampling is not supported for fully Bayesian models",
            )

        super().__init__(model=model)
        self.batch_size: Optional[int] = None

    def redraw(self) -> None:
        # Draw a fresh set of posterior sample paths via Matheron's rule
        # (pathwise conditioning); one path per element of the q-batch.
        self.samples = get_matheron_path_model(
            model=self.model, sample_shape=torch.Size([self.batch_size])
        )

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the pathwise posterior sample draws on the candidate set X.

        Args:
            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
            evaluations on the posterior sample draws.
        """
        batch_size = X.shape[-2]
        q_dim = -2

        # batch_shape x q x 1 x d
        X = X.unsqueeze(-2)
        if self.batch_size is None:
            self.batch_size = batch_size
            self.redraw()
        elif self.batch_size != batch_size:
            raise ValueError(
                BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
            )

        # posterior_values.shape post-squeeze:
        # batch_shape x q x m
        posterior_values = self.samples(X).squeeze(-2)
        # sum over the q-batch dim and squeeze the num_objectives dim (-1)
        return posterior_values.sum(q_dim).squeeze(-1)
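To make the class's semantics concrete: get_matheron_path_model returns a deterministic model wrapping fixed posterior sample paths (one per element of sample_shape), which is why repeated forward passes agree until redraw() resamples. A minimal sketch with illustrative toy data:

import torch
from botorch.models import SingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy data, illustrative only
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

# Three fixed sample paths; evaluations are deterministic until new paths
# are drawn, so evaluating twice gives identical values.
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([3]))
X = torch.rand(5, 2, dtype=torch.float64)
vals = paths(X)  # expected shape here: 3 x 5 x 1, one batch per sample path
assert torch.equal(vals, paths(X))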
@@ -0,0 +1,136 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product

import torch
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.utils.testing import BotorchTestCase


def get_model(train_X, train_Y, standardize_model):
    if standardize_model:
        outcome_transform = Standardize(m=1)
    else:
        outcome_transform = None
    model = SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        outcome_transform=outcome_transform,
    )
    return model


def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
    mcmc_samples = {
        "lengthscale": torch.rand(num_samples, 1, dim, **tkwargs),
        "outputscale": torch.rand(num_samples, **tkwargs),
        "mean": torch.randn(num_samples, **tkwargs),
    }
    if infer_noise:
        mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
    return mcmc_samples


def get_fully_bayesian_model(
    train_X,
    train_Y,
    num_models,
    **tkwargs,
):
    model = SaasFullyBayesianSingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
    )
    mcmc_samples = _get_mcmc_samples(
        num_samples=num_models,
        dim=train_X.shape[-1],
        infer_noise=True,
        **tkwargs,
    )
    model.load_mcmc_samples(mcmc_samples)
    return model


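For context on these helpers: they bypass MCMC by loading random hyperparameter draws, whereas outside of tests the SAAS model would be fit with NUTS. A minimal sketch, assuming botorch's fit_fully_bayesian_model_nuts (the small step counts are illustrative):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(4, 2, dtype=torch.float64)
train_Y = torch.rand(4, 1, dtype=torch.float64)
gp = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(gp, warmup_steps=32, num_samples=16, thinning=4)
# Fitted or mocked, passing the model to PathwiseThompsonSampling raises
# NotImplementedError, as test_thompson_sampling_fully_bayesian below checks.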
class TestPathwiseThompsonSampling(BotorchTestCase):
    def _test_thompson_sampling_base(self, model: Model):
        acq = PathwiseThompsonSampling(
            model=model,
        )
        X_observed = model.train_inputs[0]
        input_dim = X_observed.shape[-1]
        test_X = torch.rand(4, 1, input_dim).to(X_observed)
        acq_pass = acq(test_X)
        self.assertTrue(acq_pass.shape == test_X.shape[:-2])

        # repeated evaluations reuse the same sample paths, so outputs match;
        # re-drawing the paths is expected to change the output
        acq_pass1 = acq(test_X)
        self.assertAllClose(acq_pass1, acq(test_X))
        acq.redraw()
        acq_pass2 = acq(test_X)
        self.assertFalse(torch.allclose(acq_pass1, acq_pass2))

    def _test_thompson_sampling_batch(self, model: Model):
        X_observed = model.train_inputs[0]
        input_dim = X_observed.shape[-1]
        batch_acq = PathwiseThompsonSampling(
            model=model,
        )
        self.assertEqual(batch_acq.batch_size, None)
        test_X = torch.rand(4, 5, input_dim).to(X_observed)
        batch_acq(test_X)
        self.assertEqual(batch_acq.batch_size, 5)
        test_X = torch.rand(4, 7, input_dim).to(X_observed)
        with self.assertRaisesRegex(
            ValueError,
            "The batch size of PathwiseThompsonSampling should not "
            "change during a forward pass - was 5, now 7. Please re-initialize "
            "the acquisition if you want to change the batch size.",
        ):
            batch_acq(test_X)

        batch_acq2 = PathwiseThompsonSampling(model)
        test_X = torch.rand(4, 7, 1, input_dim).to(X_observed)
        self.assertEqual(batch_acq2(test_X).shape, test_X.shape[:-2])

        batch_acq3 = PathwiseThompsonSampling(model)
        test_X = torch.rand(4, 7, 3, input_dim).to(X_observed)
        self.assertEqual(batch_acq3(test_X).shape, test_X.shape[:-2])

    def test_thompson_sampling_single_task(self):
        input_dim = 2
        num_objectives = 1
        for dtype, standardize_model in product(
            (torch.float32, torch.float64), (True, False)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X = torch.rand(4, input_dim, **tkwargs)
            train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)
            model = get_model(train_X, train_Y, standardize_model=standardize_model)
            self._test_thompson_sampling_base(model)
            self._test_thompson_sampling_batch(model)

    def test_thompson_sampling_fully_bayesian(self):
        input_dim = 2
        num_objectives = 1
        tkwargs = {"device": self.device, "dtype": torch.float64}
        train_X = torch.rand(4, input_dim, **tkwargs)
        train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)

        fb_model = get_fully_bayesian_model(train_X, train_Y, num_models=3, **tkwargs)
        with self.assertRaisesRegex(
            NotImplementedError,
            "PathwiseThompsonSampling is not supported for fully Bayesian models",
        ):
            PathwiseThompsonSampling(model=fb_model)
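Rounding out the batch semantics exercised above: because forward() sums one path evaluation per q-batch element, maximizing the acquisition with q > 1 selects several Thompson samples jointly. A sketch, again with illustrative toy data and optimize_acqf settings:

import torch
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy data, illustrative only
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)

acqf = PathwiseThompsonSampling(model=model)
candidates, _ = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64),
    q=3,  # the batch size is fixed at 3 on the first forward pass
    num_restarts=4,
    raw_samples=64,
)
# candidates has shape 3 x 2: one candidate per fixed sample path.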