Commit e8cbbae

esantorella authored and facebook-github-bot committed
make qNIPV not an AnalyticAcquisitionFunction; optimize_acqf support clarity (#2286)
Summary:
Pull Request resolved: #2286

I put a `LogExpectedImprovement` instance into `optimize_acqf`, and when I got an error about it not having an attribute `X_pending`, I was not sure whether this was a bug or whether I had done something known to be unsupported.

- Make `qNegIntegratedPosteriorVariance` inherit from `AcquisitionFunction` rather than `AnalyticAcquisitionFunction`, because the functionality it was inheriting from `AnalyticAcquisitionFunction` was not relevant to it.
- `qNegIntegratedPosteriorVariance` loses an error message about not supporting multi-output models with a `PosteriorTransform` that is not scalarized, and gains a unit test showing that it does support them.

Reviewed By: Balandat

Differential Revision: D55843171

fbshipit-source-id: 87cbc84783c61e09f5f8dce935422b429ad8699d
1 parent: 968e465 · commit: e8cbbae
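For context before the diffs, here is a minimal usage sketch of the multi-output case the new unit test exercises. The toy model, data, and shapes below are illustrative assumptions, not code from this PR:

```python
import torch
from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.models import SingleTaskGP

# Illustrative only: a tiny two-output GP; any multi-output botorch model works.
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)  # two outputs
model = SingleTaskGP(train_X, train_Y)

qnipv = qNegIntegratedPosteriorVariance(
    model=model,
    # points at which the posterior variance is MC-integrated
    mc_points=torch.rand(64, 2, dtype=torch.double),
    # scalarizes the two outputs; per this PR's summary, multi-output
    # models are supported and the old construction-time error is gone
    posterior_transform=ScalarizedPosteriorTransform(
        weights=torch.tensor([0.5, 0.5], dtype=torch.double)
    ),
)
X = torch.rand(4, 1, 2, dtype=torch.double)  # 4 t-batches of q=1 candidates
print(qnipv(X).shape)  # torch.Size([4])
```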

File tree: 3 files changed, +30 −21 lines

botorch/acquisition/active_learning.py (+4 −3)
```diff
@@ -27,7 +27,7 @@
 
 import torch
 from botorch import settings
-from botorch.acquisition.analytic import AnalyticAcquisitionFunction
+from botorch.acquisition.acquisition import AcquisitionFunction
 from botorch.acquisition.monte_carlo import MCAcquisitionFunction
 from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
 from botorch.models.model import Model
@@ -37,7 +37,7 @@
 from torch import Tensor
 
 
-class qNegIntegratedPosteriorVariance(AnalyticAcquisitionFunction):
+class qNegIntegratedPosteriorVariance(AcquisitionFunction):
     r"""Batch Integrated Negative Posterior Variance for Active Learning.
 
     This acquisition function quantifies the (negative) integrated posterior variance
@@ -75,7 +75,8 @@ def __init__(
                 points that have been submitted for function evaluation but
                 have not yet been evaluated.
         """
-        super().__init__(model=model, posterior_transform=posterior_transform)
+        super().__init__(model=model)
+        self.posterior_transform = posterior_transform
         if sampler is None:
             # If no sampler is provided, we use the following dummy sampler for the
             # fantasize() method in forward. IMPORTANT: This assumes that the posterior
```
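Sketched outside the diff, the new constructor pattern looks roughly like the following. `SketchQNIPV` is a hypothetical stand-in with an abbreviated signature, not the real class:

```python
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.model import Model
from torch import Tensor


class SketchQNIPV(AcquisitionFunction):
    """Abbreviated stand-in showing the new inheritance/init pattern."""

    def __init__(
        self,
        model: Model,
        mc_points: Tensor,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        # AcquisitionFunction.__init__ takes only `model`, so the subclass
        # stores the transform itself; the analytic base class's handling
        # (including its multi-output check) no longer applies.
        super().__init__(model=model)
        self.posterior_transform = posterior_transform
        self.register_buffer("mc_points", mc_points)
```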

botorch/optim/optimize.py (+3 −1)
```diff
@@ -153,6 +153,7 @@ def _raise_deprecation_warning_if_kwargs(fn_name: str, kwargs: Dict[str, Any]) -
         f"`{fn_name}` does not support arguments {list(kwargs.keys())}. In "
         "the future, this will become an error.",
         DeprecationWarning,
+        stacklevel=2,
     )
 
 
@@ -366,7 +367,7 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
            f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
            "set of initial conditions."
        )
-        warnings.warn(first_warn_msg, RuntimeWarning)
+        warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
 
        if not initial_conditions_provided:
            batch_initial_conditions = opt_inputs.get_ic_generator()(
@@ -392,6 +393,7 @@
                "Optimization failed on the second try, after generating a "
                "new set of initial conditions.",
                RuntimeWarning,
+                stacklevel=2,
            )
 
    if opt_inputs.post_processing_func is not None:
```
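The `stacklevel=2` additions bundled into this commit make Python attribute each warning to the code that called the warning helper rather than to the `warnings.warn` line itself. A self-contained illustration with generic stand-in functions, not botorch code:

```python
import warnings


def deprecated_helper() -> None:
    # stacklevel=1 (the default) would report this line as the warning's
    # source; stacklevel=2 walks one frame up and reports the caller's line.
    warnings.warn("use new_helper() instead", DeprecationWarning, stacklevel=2)


def user_code() -> None:
    deprecated_helper()  # with stacklevel=2, the warning points here


warnings.simplefilter("always")
user_code()
```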

test/acquisition/test_active_learning.py (+23 −17)
```diff
@@ -77,31 +77,37 @@ def test_q_neg_int_post_variance(self):
             val = qNIPV(X)
             val_exp = -variance.mean(dim=-2).squeeze(-1)
             self.assertAllClose(val, val_exp, atol=1e-4)
+
             # multi-output model
             mean = torch.zeros(4, 2, device=self.device, dtype=dtype)
             variance = torch.rand(4, 2, device=self.device, dtype=dtype)
             cov = torch.diag_embed(variance.view(-1))
             f_posterior = GPyTorchPosterior(MultitaskMultivariateNormal(mean, cov))
             mc_points = torch.rand(10, 1, device=self.device, dtype=dtype)
             mfm = MockModel(f_posterior)
-            with mock.patch.object(MockModel, "fantasize", return_value=mfm):
-                with mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
-                    mock_num_outputs.return_value = 2
-                    mm = MockModel(None)
+            with mock.patch.object(
+                MockModel, "fantasize", return_value=mfm
+            ), mock.patch(no, new_callable=mock.PropertyMock) as mock_num_outputs:
+                mock_num_outputs.return_value = 2
+                mm = MockModel(None)
+
+                weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
+                qNIPV = qNegIntegratedPosteriorVariance(
+                    model=mm,
+                    mc_points=mc_points,
+                    posterior_transform=ScalarizedPosteriorTransform(weights=weights),
+                )
+                X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
+                val = qNIPV(X)
+                self.assertAllClose(val, -0.5 * variance.mean(), atol=1e-4)
+                # without posterior_transform
+                qNIPV = qNegIntegratedPosteriorVariance(
+                    model=mm,
+                    mc_points=mc_points,
+                )
+                val = qNIPV(X)
+                self.assertAllClose(val, -variance.mean(0), atol=1e-4)
 
-                    weights = torch.tensor([0.5, 0.5], device=self.device, dtype=dtype)
-                    qNIPV = qNegIntegratedPosteriorVariance(
-                        model=mm,
-                        mc_points=mc_points,
-                        posterior_transform=ScalarizedPosteriorTransform(
-                            weights=weights
-                        ),
-                    )
-                    X = torch.empty(1, 1, device=self.device, dtype=dtype)  # dummy
-                    val = qNIPV(X)
-                    self.assertTrue(
-                        torch.allclose(val, -0.5 * variance.mean(), atol=1e-4)
-                    )
             # batched multi-output model
             mean = torch.zeros(4, 3, 1, 2, device=self.device, dtype=dtype)
             variance = torch.rand(4, 3, 1, 2, device=self.device, dtype=dtype)
```
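Two mechanical changes in the test are worth noting: the nested `with` statements are merged into a single statement (identical behavior, one less indentation level), and `assertTrue(torch.allclose(...))` becomes `assertAllClose(...)`, which reports the offending values on failure rather than a bare assertion error. A generic sketch of the merged-context-manager pattern, with made-up class and patch targets:

```python
import os
from unittest import mock


class Thing:
    def fetch(self) -> str:
        return "real"


# Before: two nested with-statements. After: one statement with two
# comma-separated context managers, which behaves identically.
with mock.patch.object(Thing, "fetch", return_value="fake"), mock.patch(
    "os.getcwd", return_value="/tmp"
):
    assert Thing().fetch() == "fake"
    assert os.getcwd() == "/tmp"
```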
