Commit dd72c8a

SebastianAment authored and facebook-github-bot committed
NoisyExpectedHypervolumeMixin (pytorch#2045)
Summary:
X-link: facebook/Ax#1909

This commit introduces `NoisyExpectedHypervolumeMixin`, a derivative of `CachedCholeskyMCSamplerMixin` that separates out much of the Pareto-partitioning required for `qNEHVI`.

Reviewed By: Balandat

Differential Revision: D50337502
1 parent bb48409 commit dd72c8a
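
As context for the diff below, here is a minimal usage sketch (not part of this commit) showing that the public `qNoisyExpectedHypervolumeImprovement` constructor is unchanged by the refactor; the model, toy data, and reference point are placeholder assumptions:

import torch
from botorch.acquisition.multi_objective.monte_carlo import (
    qNoisyExpectedHypervolumeImprovement,
)
from botorch.models import SingleTaskGP

# Toy two-objective data (placeholders, not from the commit).
train_X = torch.rand(20, 3, dtype=torch.double)
train_Y = torch.stack([train_X.sum(-1), -(train_X**2).sum(-1)], dim=-1)
model = SingleTaskGP(train_X, train_Y)  # two-outcome model

acqf = qNoisyExpectedHypervolumeImprovement(
    model=model,
    ref_point=[0.0, -3.0],  # one entry per objective
    X_baseline=train_X,     # previously evaluated designs
    prune_baseline=True,    # drop points unlikely to be Pareto-optimal
    cache_root=True,        # cache the root decomposition over X_baseline
)
X_cand = torch.rand(1, 4, 3, dtype=torch.double)  # one t-batch of q=4 candidates
value = acqf(X_cand)  # tensor of shape (1,)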

5 files changed: +399 -284 lines changed


botorch/acquisition/multi_objective/monte_carlo.py (+33 -277)
@@ -6,6 +6,9 @@
 
 r"""
 Monte-Carlo Acquisition Functions for Multi-objective Bayesian optimization.
+In particular, this module contains implementations of
+1) qEHVI [Daulton2020qehvi]_, and
+2) qNEHVI [Daulton2021nehvi]_.
 
 References
 ----------
@@ -23,42 +26,27 @@
 
 from __future__ import annotations
 
-import warnings
 from abc import ABC, abstractmethod
-from copy import deepcopy
 from typing import Callable, List, Optional, Union
 
 import torch
 from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
-from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin
 from botorch.acquisition.multi_objective.objective import (
     IdentityMCMultiOutputObjective,
     MCMultiOutputObjective,
 )
-from botorch.acquisition.multi_objective.utils import (
-    prune_inferior_points_multi_objective,
-)
 from botorch.exceptions.errors import UnsupportedError
-from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.model import Model
 from botorch.models.transforms.input import InputPerturbation
 from botorch.sampling.base import MCSampler
-from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
-    BoxDecompositionList,
-)
-from botorch.utils.multi_objective.box_decompositions.dominated import (
-    DominatedPartitioning,
-)
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
-    FastNondominatedPartitioning,
     NondominatedPartitioning,
 )
-from botorch.utils.multi_objective.box_decompositions.utils import (
-    _pad_batch_pareto_frontier,
+from botorch.utils.multi_objective.hypervolume import (
+    NoisyExpectedHypervolumeMixin,
+    SubsetIndexCachingMixin,
 )
-from botorch.utils.multi_objective.hypervolume import SubsetIndexCachingMixin
 from botorch.utils.objective import compute_smoothed_feasibility_indicator
-from botorch.utils.torch import BufferDict
 from botorch.utils.transforms import (
     concatenate_pending_points,
     is_fully_bayesian,
@@ -250,7 +238,9 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
         q = obj.shape[-2]
         if self.constraints is not None:
             feas_weights = compute_smoothed_feasibility_indicator(
-                constraints=self.constraints, samples=samples, eta=self.eta
+                constraints=self.constraints,
+                samples=samples,
+                eta=self.eta,
             )  # `sample_shape x batch-shape x q`
         device = self.ref_point.device
         q_subset_indices = self.compute_q_subset_indices(q_out=q, device=device)
@@ -326,7 +316,7 @@ def forward(self, X: Tensor) -> Tensor:
 
 
 class qNoisyExpectedHypervolumeImprovement(
-    qExpectedHypervolumeImprovement, CachedCholeskyMCSamplerMixin
+    NoisyExpectedHypervolumeMixin, qExpectedHypervolumeImprovement
 ):
     def __init__(
         self,
@@ -381,10 +371,10 @@ def __init__(
                 have not yet been evaluated.
             eta: The temperature parameter for the sigmoid function used for the
                 differentiable approximation of the constraints. In case of a float the
-                same eta is used for every constraint in constraints. In case of a
+                same `eta` is used for every constraint in constraints. In case of a
                 tensor the length of the tensor must match the number of provided
                 constraints. The i-th constraint is then estimated with the i-th
-                eta value. For more details, on this parameter, see the docs of
+                `eta` value. For more details, on this parameter, see the docs of
                 `compute_smoothed_feasibility_indicator`.
             prune_baseline: If True, remove points in `X_baseline` that are
                 highly unlikely to be the pareto optimal and better than the
@@ -407,268 +397,34 @@ def __init__(
                 `q` points.
             cache_root: A boolean indicating whether to cache the root
                 decomposition over `X_baseline` and use low-rank updates.
+            marginalize_dim: A batch dimension that should be marginalized. For example,
+                this is useful when using a batched fully Bayesian model.
         """
-        if len(ref_point) < 2:
-            raise ValueError(
-                "qNoisyExpectedHypervolumeImprovement supports m>=2 outcomes "
-                f"but ref_point has length {len(ref_point)}, which is smaller than 2."
-            )
-        ref_point = torch.as_tensor(
-            ref_point, dtype=X_baseline.dtype, device=X_baseline.device
-        )
-        super(qExpectedHypervolumeImprovement, self).__init__(
+        MultiObjectiveMCAcquisitionFunction.__init__(
+            self,
             model=model,
             sampler=sampler,
             objective=objective,
             constraints=constraints,
             eta=eta,
         )
-        CachedCholeskyMCSamplerMixin.__init__(
-            self, model=model, cache_root=cache_root, sampler=sampler
-        )
-
-        if X_baseline.ndim > 2:
-            raise UnsupportedError(
-                "qNoisyExpectedHypervolumeImprovement does not support batched "
-                f"X_baseline. Expected 2 dims, got {X_baseline.ndim}."
-            )
-        if prune_baseline:
-            X_baseline = prune_inferior_points_multi_objective(
-                model=model,
-                X=X_baseline,
-                objective=objective,
-                constraints=constraints,
-                ref_point=ref_point,
-                marginalize_dim=marginalize_dim,
-            )
-        self.register_buffer("ref_point", ref_point)
-        self.alpha = alpha
-        self.q_in = -1
-        self.q_out = -1
-        self.q_subset_indices = BufferDict()
-        self.partitioning = None
-        # set partitioning class and args
-        self.p_kwargs = {}
-        if self.alpha > 0:
-            self.p_kwargs["alpha"] = self.alpha
-            self.p_class = NondominatedPartitioning
-        else:
-            self.p_class = FastNondominatedPartitioning
-        self.register_buffer("_X_baseline", X_baseline)
-        self.register_buffer("_X_baseline_and_pending", X_baseline)
-        self.register_buffer(
-            "cache_pending",
-            torch.tensor(cache_pending, dtype=bool),
-        )
-        self.register_buffer(
-            "_prev_nehvi",
-            torch.tensor(0.0, dtype=ref_point.dtype, device=ref_point.device),
-        )
-        self.register_buffer(
-            "_max_iep",
-            torch.tensor(max_iep, dtype=torch.long),
-        )
-        self.register_buffer(
-            "incremental_nehvi",
-            torch.tensor(incremental_nehvi, dtype=torch.bool),
-        )
-
-        # Base sampler is initialized in _set_cell_bounds.
-        self.base_sampler = None
-
-        if X_pending is not None:
-            # This will call self._set_cell_bounds if the number of pending
-            # points is greater than self._max_iep.
-            self.set_X_pending(X_pending)
-        # In the case that X_pending is not None, but there are fewer than
-        # max_iep pending points, the box decompositions are not performed in
-        # set_X_pending. Therefore, we need to perform a box decomposition over
-        # f(X_baseline) here.
-        if X_pending is None or X_pending.shape[-2] <= self._max_iep:
-            self._set_cell_bounds(num_new_points=X_baseline.shape[0])
-        # Set q_in=-1 to so that self.sampler is updated at the next forward call.
-        self.q_in = -1
-
-    @property
-    def X_baseline(self) -> Tensor:
-        r"""Return X_baseline augmented with pending points cached using CBD."""
-        return self._X_baseline_and_pending
-
-    def _compute_initial_hvs(self, obj: Tensor, feas: Optional[Tensor] = None) -> None:
-        r"""Compute hypervolume dominated by f(X_baseline) under each sample.
-
-        Args:
-            obj: A `sample_shape x batch_shape x n x m`-dim tensor of samples
-                of objectives.
-            feas: `sample_shape x batch_shape x n`-dim tensor of samples
-                of feasibility indicators.
-        """
-        initial_hvs = []
-        for i, sample in enumerate(obj):
-            if self.constraints is not None:
-                sample = sample[feas[i]]
-            dominated_partitioning = DominatedPartitioning(
-                ref_point=self.ref_point,
-                Y=sample,
-            )
-            hv = dominated_partitioning.compute_hypervolume()
-            initial_hvs.append(hv)
-        self.register_buffer(
-            "_initial_hvs",
-            torch.tensor(initial_hvs, dtype=obj.dtype, device=obj.device).view(
-                self._batch_sample_shape, *obj.shape[-2:]
-            ),
-        )
-
-    def _set_cell_bounds(self, num_new_points: int) -> None:
-        r"""Compute the box decomposition under each posterior sample.
-
-        Args:
-            num_new_points: The number of new points (beyond the points
-                in X_baseline) that were used in the previous box decomposition.
-                In the first box decomposition, this should be the number of points
-                in X_baseline.
-        """
-        feas = None
-        if self.X_baseline.shape[0] > 0:
-            with torch.no_grad():
-                posterior = self.model.posterior(self.X_baseline)
-            # Reset sampler, accounting for possible one-to-many transform.
-            self.q_in = -1
-            if self.base_sampler is None:
-                # Initialize the base sampler if needed.
-                samples = self.get_posterior_samples(posterior)
-                self.base_sampler = deepcopy(self.sampler)
-            else:
-                samples = self.base_sampler(posterior)
-            n_w = posterior._extended_shape()[-2] // self.X_baseline.shape[-2]
-            self._set_sampler(q_in=num_new_points * n_w, posterior=posterior)
-            # cache posterior
-            if self._cache_root:
-                # Note that this implicitly uses LinearOperator's caching to check if
-                # the proper root decomposition has already been cached to
-                # `posterior.mvn.lazy_covariance_matrix`, which it may have been in
-                # the call to `self.base_sampler`, and computes it if not found
-                self._baseline_L = self._compute_root_decomposition(posterior=posterior)
-            obj = self.objective(samples, X=self.X_baseline)
-            if self.constraints is not None:
-                feas = torch.stack(
-                    [c(samples) <= 0 for c in self.constraints], dim=0
-                ).all(dim=0)
-        else:
-            sample_shape = (
-                self.sampler.sample_shape
-                if self.sampler is not None
-                else self._default_sample_shape
-            )
-            obj = torch.empty(
-                *sample_shape,
-                0,
-                self.ref_point.shape[-1],
-                dtype=self.ref_point.dtype,
-                device=self.ref_point.device,
-            )
-        self._batch_sample_shape = obj.shape[:-2]
-        # collapse batch dimensions
-        # use numel() rather than view(-1) to handle case of no baseline points
-        new_batch_shape = self._batch_sample_shape.numel()
-        obj = obj.view(new_batch_shape, *obj.shape[-2:])
-        if self.constraints is not None and feas is not None:
-            feas = feas.view(new_batch_shape, *feas.shape[-1:])
-
-        if self.partitioning is None and not self.incremental_nehvi:
-            self._compute_initial_hvs(obj=obj, feas=feas)
-        if self.ref_point.shape[-1] > 2:
-            # the partitioning algorithms run faster on the CPU
-            # due to advanced indexing
-            ref_point_cpu = self.ref_point.cpu()
-            obj_cpu = obj.cpu()
-            if self.constraints is not None and feas is not None:
-                feas_cpu = feas.cpu()
-                obj_cpu = [obj_cpu[i][feas_cpu[i]] for i in range(obj.shape[0])]
-            partitionings = []
-            for sample in obj_cpu:
-                partitioning = self.p_class(
-                    ref_point=ref_point_cpu, Y=sample, **self.p_kwargs
-                )
-                partitionings.append(partitioning)
-            self.partitioning = BoxDecompositionList(*partitionings)
-        else:
-            # use batched partitioning
-            obj = _pad_batch_pareto_frontier(
-                Y=obj,
-                ref_point=self.ref_point.unsqueeze(0).expand(
-                    obj.shape[0], self.ref_point.shape[-1]
-                ),
-                feasibility_mask=feas,
-            )
-            self.partitioning = self.p_class(
-                ref_point=self.ref_point, Y=obj, **self.p_kwargs
-            )
-        cell_bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point)
-        cell_bounds = cell_bounds.view(
-            2, *self._batch_sample_shape, *cell_bounds.shape[-2:]
-        )
-        self.register_buffer("cell_lower_bounds", cell_bounds[0])
-        self.register_buffer("cell_upper_bounds", cell_bounds[1])
-
-    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
-        r"""Informs the acquisition function about pending design points.
-
-        Args:
-            X_pending: `n x d` Tensor with `n` `d`-dim design points that have
-                been submitted for evaluation but have not yet been evaluated.
-        """
-        if X_pending is None:
-            self.X_pending = None
-        else:
-            if X_pending.requires_grad:
-                warnings.warn(
-                    "Pending points require a gradient but the acquisition function"
-                    " will not provide a gradient to these points.",
-                    BotorchWarning,
-                )
-            X_pending = X_pending.detach().clone()
-            if self.cache_pending:
-                X_baseline = torch.cat([self._X_baseline, X_pending], dim=-2)
-                # Number of new points is the total number of points minus
-                # (the number of previously cached pending points plus the
-                # of number of baseline points).
-                num_new_points = X_baseline.shape[0] - self.X_baseline.shape[0]
-                if num_new_points > 0:
-                    if num_new_points > self._max_iep:
-                        # Set the new baseline points to include pending points.
-                        self.register_buffer("_X_baseline_and_pending", X_baseline)
-                        # Recompute box decompositions.
-                        self._set_cell_bounds(num_new_points=num_new_points)
-                        if not self.incremental_nehvi:
-                            self._prev_nehvi = (
-                                (self._hypervolumes - self._initial_hvs)
-                                .clamp_min(0.0)
-                                .mean()
-                            )
-                        # Set to None so that pending points are not concatenated in
-                        # forward.
-                        self.X_pending = None
-                        # Set q_in=-1 to so that self.sampler is updated at the next
-                        # forward call.
-                        self.q_in = -1
-                    else:
-                        self.X_pending = X_pending[-num_new_points:]
-            else:
-                self.X_pending = X_pending
-
-    @property
-    def _hypervolumes(self) -> Tensor:
-        r"""Compute hypervolume over X_baseline under each posterior sample.
-
-        Returns:
-            A `n_samples`-dim tensor of hypervolumes.
-        """
-        return (
-            self.partitioning.compute_hypervolume()
-            .to(self.ref_point)  # for m > 2, the partitioning is on the CPU
-            .view(self._batch_sample_shape)
+        SubsetIndexCachingMixin.__init__(self)
+        NoisyExpectedHypervolumeMixin.__init__(
+            self,
+            model=model,
+            ref_point=ref_point,
+            X_baseline=X_baseline,
+            sampler=self.sampler,
+            objective=self.objective,
+            constraints=self.constraints,
+            X_pending=X_pending,
+            prune_baseline=prune_baseline,
+            alpha=alpha,
+            cache_pending=cache_pending,
+            max_iep=max_iep,
+            incremental_nehvi=incremental_nehvi,
+            cache_root=cache_root,
+            marginalize_dim=marginalize_dim,
         )
 
     @concatenate_pending_points
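
The motivation for the mixin is that the noisy Pareto-partitioning machinery can now be reused outside of `qExpectedHypervolumeImprovement`. The sketch below is hypothetical and not part of this commit: it mirrors the constructor call in the new `__init__` above, uses import paths as of this commit, and assumes the mixin accepts the same argument values as qNEHVI's defaults and still registers the `cell_lower_bounds`/`cell_upper_bounds` buffers that the removed `_set_cell_bounds` registered.

from typing import List, Union

from torch import Tensor

from botorch.acquisition.multi_objective.monte_carlo import (
    MultiObjectiveMCAcquisitionFunction,
)
from botorch.models.model import Model
from botorch.utils.multi_objective.hypervolume import NoisyExpectedHypervolumeMixin


class MyNoisyHypervolumeAcqf(
    NoisyExpectedHypervolumeMixin, MultiObjectiveMCAcquisitionFunction
):
    """Hypothetical acquisition function reusing the new mixin (illustration only)."""

    def __init__(
        self,
        model: Model,
        ref_point: Union[List[float], Tensor],
        X_baseline: Tensor,
        prune_baseline: bool = True,
        cache_root: bool = True,
    ) -> None:
        # Initialize the generic MC acquisition machinery first (this sets
        # self.sampler / self.objective), mirroring the new qNEHVI __init__.
        MultiObjectiveMCAcquisitionFunction.__init__(self, model=model)
        # Keyword names are taken from the call shown in the diff above; the
        # values passed here mirror qNEHVI's defaults (an assumption).
        NoisyExpectedHypervolumeMixin.__init__(
            self,
            model=model,
            ref_point=ref_point,
            X_baseline=X_baseline,
            sampler=self.sampler,
            objective=self.objective,
            constraints=None,
            X_pending=None,
            prune_baseline=prune_baseline,
            alpha=0.0,
            cache_pending=True,
            max_iep=0,
            incremental_nehvi=True,
            cache_root=cache_root,
            marginalize_dim=None,
        )

    def forward(self, X: Tensor) -> Tensor:
        # Placeholder: a real implementation would combine posterior samples at
        # X with the cached cell bounds, as qNEHVI's _compute_qehvi does with
        # self.cell_lower_bounds / self.cell_upper_bounds.
        raise NotImplementedError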
