Vizier output transforms #2643
base: main
Conversation
Thanks for adding these. I've been curious to try out the Vizier transforms. These transforms try to ensure we model the good regions better, which seems better than trying to model everything equally well when the goal is to optimize a function. Have you tried these on any benchmarks yet? It'd be good to see how they work e2e.
None of these transforms implement the untransform_posterior method, which is necessary to make the transforms work as an outcome_transform in BoTorch models. Without it, Model.posterior will error out. Would it be possible to implement untransform_posterior, likely returning a TransformedPosterior object?
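A hedged sketch of what this could look like, assuming the transform applies a pure constant shift (the _shift attribute appears in the PR's tests, but the affine form and direction here are assumptions, not the PR's actual code):

from botorch.posteriors.transformed import TransformedPosterior

def untransform_posterior(self, posterior):
    # Sketch: if the forward transform is Y -> Y + shift, the posterior can be
    # untransformed analytically; a constant shift leaves the variance unchanged.
    shift = self._shift  # assumed attribute set during the forward pass
    return TransformedPosterior(
        posterior=posterior,
        sample_transform=lambda samples: samples - shift,
        mean_transform=lambda mean, var: mean - shift,
        variance_transform=lambda mean, var: var,
    )

Transforms with nonlinear warps would instead pass the appropriate inverse-warp callables; the Log outcome transform in BoTorch follows the same pattern.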
None of the transforms seem to support Yvar either. We use BoTorch through Ax and try to keep the same methods regardless of whether the noise is observed. What does Vizier do with observation noise? Is it simply ignored, or is this just a limitation of the current implementation?
assert transform._batch_shape == batch_shape
assert not transform._is_trained
assert transform._shift is None
assert torch.isnan(transform.warped_bad_value)
We tend to use self.assertEqual(...), self.assertTrue/False(...), etc. from TestCase rather than using bare assert in tests.
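For example, the asserts above could be written as follows (the .all() is a hedge in case warped_bad_value is a non-scalar tensor):

self.assertEqual(transform._batch_shape, batch_shape)
self.assertFalse(transform._is_trained)
self.assertIsNone(transform._shift)
self.assertTrue(torch.isnan(transform.warped_bad_value).all())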
botorch/models/transforms/outcome.py
if Yvar is not None:
    raise NotImplementedError(
        "InfeasibleTransform does not support transforming observation noise"
    )
It'd be great to support Yvar as well. Could we just set it to the mean of all Yvars? That would only minimally affect the model predictions and let the code run.
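A minimal sketch of that suggestion, assuming Yvar has the usual batch_shape x n x m layout; this is an approximation standing in for the NotImplementedError above, not the PR's behavior:

if Yvar is not None:
    # Approximation: collapse per-observation noise to its mean over the n
    # (observation) dimension so downstream code has a well-defined noise level.
    Yvar = Yvar.mean(dim=-2, keepdim=True).expand_as(Yvar)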
y = Y_transformed[*batch_idx, :, dim]

# Get finite values and their ranks for each batch
is_finite_mask = ~torch.isnan(y)
BoTorch models do not support NaN inputs. If we didn't need to deal with NaN and somehow avoided unique, could we avoid the double for loop here?
I think this is the salient point for me, which indicates that I should probably have done this in Ax. Would Ax error out if a metric returned NaN before applying transforms, or only after?
I believe the error would come after, though I would have to double check. IIRC, the NaN will propagate down to BoTorch and raise an error there.
I am planning a major refactor of data handling within the Ax ModelBridge layer & transforms. The goal is to reduce looping to make it more efficient & scalable. Though if you want to add these in Ax now, it doesn't have to be blocked on the refactor.
botorch/models/transforms/outcome.py
for i, val in enumerate(unique_y):
    ranks[y == val] = i + 1
If we didn't care about uniqueness, we could get the rank / order using torch.argsort. Imo, duplicate observations should be pretty rare in practice and are probably safe to ignore. They're probably mostly introduced by the infeasibility warper.
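A sketch of the argsort-based alternative for a 1-d y, ignoring ties (so duplicates get distinct ranks rather than sharing one):

order = torch.argsort(y)
ranks = torch.empty_like(y, dtype=torch.long)
# Element order[i] is the i-th smallest value, so it receives rank i + 1.
ranks[order] = torch.arange(1, y.numel() + 1, device=y.device)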
I guess we could also find the median using the middle index if we simply sorted Y from the start.
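For instance, a sketch assuming Y has shape ... x n x m and taking the lower-middle element for even n:

y_sorted, _ = torch.sort(Y, dim=-2)
n = y_sorted.shape[-2]
median = y_sorted[..., (n - 1) // 2, :]  # lower-middle element as the median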
I think it's worth keeping this generality in case other packages like BoFire want to make use of these methods. I tried to stick as close to the reference implementation as possible while expanding it to handle multiple dimensions and batch_size.
# Process values below median
below_median = y < self._original_label_medians[*batch_idx, dim]
if below_median.any():
This is quite a complicated process to untransform. I find it quite difficult to review since I don't know what the algorithm is supposed to do.
The HalfRankTransform splits the data at the median and then warps the bottom half onto a Gaussian whose std matches the std of the top half of the distribution. On inverting there are three cases: (a) the value is very close to a known value from the fit, so we can just look it up; (b) within the support we take a linear interpolation; (c) outside of the support we take a linear extrapolation based on the entire data range.
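A rough, hypothetical 1-d sketch of the forward warp just described (the function name and simplifications are mine, not the PR's code; it skips batching, NaN handling, and the tie/extrapolation logic):

import torch
from torch.distributions import Normal

def half_rank_warp_1d(y: torch.Tensor) -> torch.Tensor:
    # Values at or above the median are kept; values below it are mapped to
    # Gaussian quantiles whose std is estimated from the spread of the upper half.
    median = y.median()
    above = y >= median
    below = ~above
    out = y.clone()
    if below.any():
        # Std of the warping Gaussian, taken from the upper half's deviations.
        std = (y[above] - median).pow(2).mean().sqrt()
        # 1-based ranks of the below-median points among themselves.
        ranks = torch.argsort(torch.argsort(y[below])).to(y.dtype) + 1.0
        # Map ranks into quantiles in (0, 0.5) so icdf lands below the median.
        quantiles = 0.5 * ranks / (below.sum() + 1)
        out[below] = median + std * Normal(0.0, 1.0).icdf(quantiles)
    return out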
I implemented these here before looking into what I would need to do to connect them to my benchmark test case that I have been prototyping in Ax. I assumed that the best approach would be to implement it lower in the stack and then bubble up, but seemingly I should have implemented it at the level of Ax (especially as the harder part here was handling the
This should be able to be done based on the logic in
I think the Vizier approach is similar to performing the transformations at the layer of Ax. Their equivalent to Metric only deals with a single float observation. As such, and from the docs, I believe observation noise is simply ignored. I don't think that we can implement Yvar for either the LogWarping or the HalfRank transform.
This is something I've also been thinking about quite a bit lately, as part of reworking what Ax uses by default (though this has been rather focused on input transforms so far). Outcome transforms may be slightly simpler in Ax. You technically only need to be able to transform, though we also want some rough untransform method for reporting & cross validation purposes. In BoTorch, you need the transform to be differentiable, since the untransformed posterior is used to compute the acquisition value, which needs to be differentiated through during optimization. Ax transforms are applied once as a pre-processing step, but BoTorch transforms are applied at each posterior call. If I were to implement Vizier transforms, I'd probably try them in Ax, since there are fewer restrictions.
Do you often use batched models? If not, you don't need to support
Agreed. It is more of a pre-processing step.
Cool, this is good to know.
Will move forward with a PR at Ax, as that seems like the faster way to make progress on my use case.
Sounds good. StandardizeY or PowerTransformY should be good examples to follow.
Motivation
I wanted to experiment with the Vizier output transforms in BoTorch, in particular the infeasible transform. This may be misplaced effort, as it appears that within the Meta ecosystem Ax actually handles this sort of thing, so these may not end up being used very much, but it seemed worth upstreaming.
Takes transforms from: https://arxiv.org/abs/2408.11527
Test Plan
Added tests for all functionality; not sure I am at 100% coverage.
Related PRs
This branch takes off from #2636, so that would need to be merged first.