ModelListGP.fantasize does not support MultiTaskGPs, but other ModelListGP methods do #2398
In the following, I present a reproducible example for this issue.
Versions used:
Reproducible Code
I am using a function to create the groups of targets that are correlated. Its output is shown by the assertion.

# imports
import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.mlls import SumMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from botorch.sampling.normal import SobolQMCNormalSampler

### data creation ###
torch.manual_seed(42)
n = 500
train_X = torch.randn(n, 12)
train_Y1 = torch.randn(n, 1)
train_Y2 = 0.8 * train_Y1 + torch.randn(n, 1) * 0.5  # correlated with train_Y1
train_Y3 = torch.randn(n, 1)  # independent of the other targets
train_Y = torch.cat([train_Y1, train_Y2, train_Y3], dim=1)
noise_targets = torch.tensor([5e-3, 1e-2, 4e-4])

# computation of correlation groups (_get_correlated_groups is my own helper, not shown here):
corr_groups = _get_correlated_groups(train_Y, threshold=0.8)
assert corr_groups == [[0, 1], [2]], f"Unexpected group creation, got {corr_groups}."

### GP creation ###
models = []
# create a model for each correlation group
for group in corr_groups:
    if len(group) > 1:
        # create a MultiTaskGP with one task per target index in the correlation group
        group_indices = torch.tensor(group)
        num_tasks = len(group_indices)
        # long format: every input row is repeated num_tasks times and the
        # task index cycles 0, 1, ..., num_tasks - 1 within each block
        task_indices = torch.arange(num_tasks).repeat(train_X.size(0), 1).view(-1, 1)
        train_x_group = torch.cat(
            [train_X.repeat(1, num_tasks).view(-1, train_X.size(-1)), task_indices], dim=-1
        )
        # targets and noise variances are flattened in the same (row, task) order
        train_y_group = train_Y[:, group_indices].view(-1, 1)
        train_yvar = torch.stack(
            [torch.full_like(train_Y[:, i], noise_targets[group_indices[i]]) for i in range(len(group))],
            1,
        ).view(-1, 1)
        model = MultiTaskGP(
            train_X=train_x_group,
            train_Y=train_y_group,
            task_feature=-1,
            train_Yvar=train_yvar,
        )
        models.append(model)
    else:
        # create a SingleTaskGP for a target that is not correlated with any other target
        ind = group[0]
        train_y_group = train_Y[..., ind:ind + 1]
        noise = noise_targets[ind]
        train_yvar = torch.full_like(train_y_group, noise)
        model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_y_group,
            train_Yvar=train_yvar,
            covar_module=ScaleKernel(RBFKernel()),
        )
        models.append(model)

# combine all models
overall_model = ModelListGP(*models)
mll = SumMarginalLogLikelihood(overall_model.likelihood, overall_model)

# fit the model using the mll
try:
    fit_gpytorch_mll(mll, max_attempts=30)
except Exception as e:
    raise Exception(f"Fitting did not work. Exception: {e}") from e

### create fantasize model ###
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
fant_model = overall_model.fantasize(train_X, sampler=sampler)

Keep in mind:
In general, I am not calling the

Error:
To me it seems that the

Expected Behaviour
I would expect the fantasize model to be created without an error. I hope this code snippet helps to determine whether the error lies in my code or is a bug.
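The reproducer calls `_get_correlated_groups`, whose implementation is not included in the issue. Below is a hypothetical, dependency-free sketch of such a helper, assuming it greedily groups target columns whose absolute Pearson correlation with the first member of a group is at least `threshold`. The real helper may work differently, and this sketch takes a list of rows (for a tensor, `train_Y.tolist()`) instead of a tensor.

```python
import math


def pearson(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    var_a = sum((x - mean_a) ** 2 for x in a)
    var_b = sum((y - mean_b) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def get_correlated_groups(rows, threshold=0.8):
    """Hypothetical stand-in for the issue's _get_correlated_groups.

    rows: list of observations, each a list with one value per target.
    Each not-yet-grouped target i starts a group and pulls in every later,
    not-yet-grouped target j with |corr(i, j)| >= threshold.
    """
    num_targets = len(rows[0])
    columns = [[row[j] for row in rows] for j in range(num_targets)]
    assigned = set()
    groups = []
    for i in range(num_targets):
        if i in assigned:
            continue
        group = [i]
        assigned.add(i)
        for j in range(i + 1, num_targets):
            if j not in assigned and abs(pearson(columns[i], columns[j])) >= threshold:
                group.append(j)
                assigned.add(j)
        groups.append(group)
    return groups


# target 1 is a linear function of target 0; target 2 is only weakly related
rows = [[1.0, 3.0, 2.0], [2.0, 5.0, -1.0], [3.0, 7.0, 1.0], [4.0, 9.0, -2.0]]
print(get_correlated_groups(rows, threshold=0.8))  # [[0, 1], [2]]
```

Here the correlation between targets 0 and 2 is about -0.71, so target 2 stays in its own group and would get a `SingleTaskGP`, matching the `[[0, 1], [2]]` grouping asserted in the reproducer.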
Discussed in #2396
Originally posted by lucky-luke-98 June 26, 2024
[....]
My idea was to build a separate model for each correlation group, depending on whether its targets are correlated with other targets (-> MultiTaskGP) or not (-> SingleTaskGP). Afterwards, to create one unified model, I would combine them in a ModelListGP. [....]
Following my code, I run into problems with this procedure when performing a multi-step lookahead (qMultiStepLookahead). There, the fantasize method of the ModelListGP assumes that the number of outputs (self.num_outputs) equals the number of models in the ModelListGP. In my case this assumption does not hold: the MultiTaskGP contributes two outputs, so the ModelListGP has three outputs but only two models.