|
23 | 23 |
|
24 | 24 | from __future__ import annotations
|
25 | 25 |
|
26 |
| -import warnings |
27 | 26 | from abc import ABC, abstractmethod
|
28 |
| -from copy import deepcopy |
29 | 27 | from typing import Callable, List, Optional, Union
|
30 | 28 |
|
31 | 29 | import torch
|
32 | 30 | from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
|
33 |
| -from botorch.acquisition.cached_cholesky import CachedCholeskyMCSamplerMixin |
34 | 31 | from botorch.acquisition.multi_objective.objective import (
|
35 | 32 | IdentityMCMultiOutputObjective,
|
36 | 33 | MCMultiOutputObjective,
|
37 | 34 | )
|
38 |
| -from botorch.acquisition.multi_objective.utils import ( |
39 |
| - prune_inferior_points_multi_objective, |
40 |
| -) |
41 | 35 | from botorch.exceptions.errors import UnsupportedError
|
42 |
| -from botorch.exceptions.warnings import BotorchWarning |
43 | 36 | from botorch.models.model import Model
|
44 | 37 | from botorch.models.transforms.input import InputPerturbation
|
45 | 38 | from botorch.sampling.base import MCSampler
|
46 |
| -from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import ( |
47 |
| - BoxDecompositionList, |
48 |
| -) |
49 |
| -from botorch.utils.multi_objective.box_decompositions.dominated import ( |
50 |
| - DominatedPartitioning, |
51 |
| -) |
52 | 39 | from botorch.utils.multi_objective.box_decompositions.non_dominated import (
|
53 |
| - FastNondominatedPartitioning, |
54 | 40 | NondominatedPartitioning,
|
55 | 41 | )
|
56 |
| -from botorch.utils.multi_objective.box_decompositions.utils import ( |
57 |
| - _pad_batch_pareto_frontier, |
| 42 | +from botorch.utils.multi_objective.hypervolume import ( |
| 43 | + NoisyExpectedHypervolumeMixin, |
| 44 | + SubsetIndexCachingMixin, |
58 | 45 | )
|
59 |
| -from botorch.utils.multi_objective.hypervolume import SubsetIndexCachingMixin |
60 | 46 | from botorch.utils.objective import compute_smoothed_feasibility_indicator
|
61 |
| -from botorch.utils.torch import BufferDict |
62 | 47 | from botorch.utils.transforms import (
|
63 | 48 | concatenate_pending_points,
|
64 | 49 | is_fully_bayesian,
|
@@ -250,7 +235,9 @@ def _compute_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
|
250 | 235 | q = obj.shape[-2]
|
251 | 236 | if self.constraints is not None:
|
252 | 237 | feas_weights = compute_smoothed_feasibility_indicator(
|
253 |
| - constraints=self.constraints, samples=samples, eta=self.eta |
| 238 | + constraints=self.constraints, |
| 239 | + samples=samples, |
| 240 | + eta=self.eta, |
254 | 241 | ) # `sample_shape x batch-shape x q`
|
255 | 242 | device = self.ref_point.device
|
256 | 243 | q_subset_indices = self.compute_q_subset_indices(q_out=q, device=device)
|
@@ -326,7 +313,7 @@ def forward(self, X: Tensor) -> Tensor:
|
326 | 313 |
|
327 | 314 |
|
328 | 315 | class qNoisyExpectedHypervolumeImprovement(
|
329 |
| - qExpectedHypervolumeImprovement, CachedCholeskyMCSamplerMixin |
| 316 | + NoisyExpectedHypervolumeMixin, qExpectedHypervolumeImprovement |
330 | 317 | ):
|
331 | 318 | def __init__(
|
332 | 319 | self,
|
@@ -407,268 +394,33 @@ def __init__(
|
407 | 394 | `q` points.
|
408 | 395 | cache_root: A boolean indicating whether to cache the root
|
409 | 396 | decomposition over `X_baseline` and use low-rank updates.
|
| 397 | + marginalize_dim: A batch dimension that should be marginalized. |
410 | 398 | """
|
411 |
| - if len(ref_point) < 2: |
412 |
| - raise ValueError( |
413 |
| - "qNoisyExpectedHypervolumeImprovement supports m>=2 outcomes " |
414 |
| - f"but ref_point has length {len(ref_point)}, which is smaller than 2." |
415 |
| - ) |
416 |
| - ref_point = torch.as_tensor( |
417 |
| - ref_point, dtype=X_baseline.dtype, device=X_baseline.device |
418 |
| - ) |
419 |
| - super(qExpectedHypervolumeImprovement, self).__init__( |
| 399 | + MultiObjectiveMCAcquisitionFunction.__init__( |
| 400 | + self, |
420 | 401 | model=model,
|
421 | 402 | sampler=sampler,
|
422 | 403 | objective=objective,
|
423 | 404 | constraints=constraints,
|
424 | 405 | eta=eta,
|
425 | 406 | )
|
426 |
| - CachedCholeskyMCSamplerMixin.__init__( |
427 |
| - self, model=model, cache_root=cache_root, sampler=sampler |
428 |
| - ) |
429 |
| - |
430 |
| - if X_baseline.ndim > 2: |
431 |
| - raise UnsupportedError( |
432 |
| - "qNoisyExpectedHypervolumeImprovement does not support batched " |
433 |
| - f"X_baseline. Expected 2 dims, got {X_baseline.ndim}." |
434 |
| - ) |
435 |
| - if prune_baseline: |
436 |
| - X_baseline = prune_inferior_points_multi_objective( |
437 |
| - model=model, |
438 |
| - X=X_baseline, |
439 |
| - objective=objective, |
440 |
| - constraints=constraints, |
441 |
| - ref_point=ref_point, |
442 |
| - marginalize_dim=marginalize_dim, |
443 |
| - ) |
444 |
| - self.register_buffer("ref_point", ref_point) |
445 |
| - self.alpha = alpha |
446 |
| - self.q_in = -1 |
447 |
| - self.q_out = -1 |
448 |
| - self.q_subset_indices = BufferDict() |
449 |
| - self.partitioning = None |
450 |
| - # set partitioning class and args |
451 |
| - self.p_kwargs = {} |
452 |
| - if self.alpha > 0: |
453 |
| - self.p_kwargs["alpha"] = self.alpha |
454 |
| - self.p_class = NondominatedPartitioning |
455 |
| - else: |
456 |
| - self.p_class = FastNondominatedPartitioning |
457 |
| - self.register_buffer("_X_baseline", X_baseline) |
458 |
| - self.register_buffer("_X_baseline_and_pending", X_baseline) |
459 |
| - self.register_buffer( |
460 |
| - "cache_pending", |
461 |
| - torch.tensor(cache_pending, dtype=bool), |
462 |
| - ) |
463 |
| - self.register_buffer( |
464 |
| - "_prev_nehvi", |
465 |
| - torch.tensor(0.0, dtype=ref_point.dtype, device=ref_point.device), |
466 |
| - ) |
467 |
| - self.register_buffer( |
468 |
| - "_max_iep", |
469 |
| - torch.tensor(max_iep, dtype=torch.long), |
470 |
| - ) |
471 |
| - self.register_buffer( |
472 |
| - "incremental_nehvi", |
473 |
| - torch.tensor(incremental_nehvi, dtype=torch.bool), |
474 |
| - ) |
475 |
| - |
476 |
| - # Base sampler is initialized in _set_cell_bounds. |
477 |
| - self.base_sampler = None |
478 |
| - |
479 |
| - if X_pending is not None: |
480 |
| - # This will call self._set_cell_bounds if the number of pending |
481 |
| - # points is greater than self._max_iep. |
482 |
| - self.set_X_pending(X_pending) |
483 |
| - # In the case that X_pending is not None, but there are fewer than |
484 |
| - # max_iep pending points, the box decompositions are not performed in |
485 |
| - # set_X_pending. Therefore, we need to perform a box decomposition over |
486 |
| - # f(X_baseline) here. |
487 |
| - if X_pending is None or X_pending.shape[-2] <= self._max_iep: |
488 |
| - self._set_cell_bounds(num_new_points=X_baseline.shape[0]) |
489 |
| - # Set q_in=-1 to so that self.sampler is updated at the next forward call. |
490 |
| - self.q_in = -1 |
491 |
| - |
492 |
| - @property |
493 |
| - def X_baseline(self) -> Tensor: |
494 |
| - r"""Return X_baseline augmented with pending points cached using CBD.""" |
495 |
| - return self._X_baseline_and_pending |
496 |
| - |
497 |
| - def _compute_initial_hvs(self, obj: Tensor, feas: Optional[Tensor] = None) -> None: |
498 |
| - r"""Compute hypervolume dominated by f(X_baseline) under each sample. |
499 |
| -
|
500 |
| - Args: |
501 |
| - obj: A `sample_shape x batch_shape x n x m`-dim tensor of samples |
502 |
| - of objectives. |
503 |
| - feas: `sample_shape x batch_shape x n`-dim tensor of samples |
504 |
| - of feasibility indicators. |
505 |
| - """ |
506 |
| - initial_hvs = [] |
507 |
| - for i, sample in enumerate(obj): |
508 |
| - if self.constraints is not None: |
509 |
| - sample = sample[feas[i]] |
510 |
| - dominated_partitioning = DominatedPartitioning( |
511 |
| - ref_point=self.ref_point, |
512 |
| - Y=sample, |
513 |
| - ) |
514 |
| - hv = dominated_partitioning.compute_hypervolume() |
515 |
| - initial_hvs.append(hv) |
516 |
| - self.register_buffer( |
517 |
| - "_initial_hvs", |
518 |
| - torch.tensor(initial_hvs, dtype=obj.dtype, device=obj.device).view( |
519 |
| - self._batch_sample_shape, *obj.shape[-2:] |
520 |
| - ), |
521 |
| - ) |
522 |
| - |
523 |
| - def _set_cell_bounds(self, num_new_points: int) -> None: |
524 |
| - r"""Compute the box decomposition under each posterior sample. |
525 |
| -
|
526 |
| - Args: |
527 |
| - num_new_points: The number of new points (beyond the points |
528 |
| - in X_baseline) that were used in the previous box decomposition. |
529 |
| - In the first box decomposition, this should be the number of points |
530 |
| - in X_baseline. |
531 |
| - """ |
532 |
| - feas = None |
533 |
| - if self.X_baseline.shape[0] > 0: |
534 |
| - with torch.no_grad(): |
535 |
| - posterior = self.model.posterior(self.X_baseline) |
536 |
| - # Reset sampler, accounting for possible one-to-many transform. |
537 |
| - self.q_in = -1 |
538 |
| - if self.base_sampler is None: |
539 |
| - # Initialize the base sampler if needed. |
540 |
| - samples = self.get_posterior_samples(posterior) |
541 |
| - self.base_sampler = deepcopy(self.sampler) |
542 |
| - else: |
543 |
| - samples = self.base_sampler(posterior) |
544 |
| - n_w = posterior._extended_shape()[-2] // self.X_baseline.shape[-2] |
545 |
| - self._set_sampler(q_in=num_new_points * n_w, posterior=posterior) |
546 |
| - # cache posterior |
547 |
| - if self._cache_root: |
548 |
| - # Note that this implicitly uses LinearOperator's caching to check if |
549 |
| - # the proper root decomposition has already been cached to |
550 |
| - # `posterior.mvn.lazy_covariance_matrix`, which it may have been in |
551 |
| - # the call to `self.base_sampler`, and computes it if not found |
552 |
| - self._baseline_L = self._compute_root_decomposition(posterior=posterior) |
553 |
| - obj = self.objective(samples, X=self.X_baseline) |
554 |
| - if self.constraints is not None: |
555 |
| - feas = torch.stack( |
556 |
| - [c(samples) <= 0 for c in self.constraints], dim=0 |
557 |
| - ).all(dim=0) |
558 |
| - else: |
559 |
| - sample_shape = ( |
560 |
| - self.sampler.sample_shape |
561 |
| - if self.sampler is not None |
562 |
| - else self._default_sample_shape |
563 |
| - ) |
564 |
| - obj = torch.empty( |
565 |
| - *sample_shape, |
566 |
| - 0, |
567 |
| - self.ref_point.shape[-1], |
568 |
| - dtype=self.ref_point.dtype, |
569 |
| - device=self.ref_point.device, |
570 |
| - ) |
571 |
| - self._batch_sample_shape = obj.shape[:-2] |
572 |
| - # collapse batch dimensions |
573 |
| - # use numel() rather than view(-1) to handle case of no baseline points |
574 |
| - new_batch_shape = self._batch_sample_shape.numel() |
575 |
| - obj = obj.view(new_batch_shape, *obj.shape[-2:]) |
576 |
| - if self.constraints is not None and feas is not None: |
577 |
| - feas = feas.view(new_batch_shape, *feas.shape[-1:]) |
578 |
| - |
579 |
| - if self.partitioning is None and not self.incremental_nehvi: |
580 |
| - self._compute_initial_hvs(obj=obj, feas=feas) |
581 |
| - if self.ref_point.shape[-1] > 2: |
582 |
| - # the partitioning algorithms run faster on the CPU |
583 |
| - # due to advanced indexing |
584 |
| - ref_point_cpu = self.ref_point.cpu() |
585 |
| - obj_cpu = obj.cpu() |
586 |
| - if self.constraints is not None and feas is not None: |
587 |
| - feas_cpu = feas.cpu() |
588 |
| - obj_cpu = [obj_cpu[i][feas_cpu[i]] for i in range(obj.shape[0])] |
589 |
| - partitionings = [] |
590 |
| - for sample in obj_cpu: |
591 |
| - partitioning = self.p_class( |
592 |
| - ref_point=ref_point_cpu, Y=sample, **self.p_kwargs |
593 |
| - ) |
594 |
| - partitionings.append(partitioning) |
595 |
| - self.partitioning = BoxDecompositionList(*partitionings) |
596 |
| - else: |
597 |
| - # use batched partitioning |
598 |
| - obj = _pad_batch_pareto_frontier( |
599 |
| - Y=obj, |
600 |
| - ref_point=self.ref_point.unsqueeze(0).expand( |
601 |
| - obj.shape[0], self.ref_point.shape[-1] |
602 |
| - ), |
603 |
| - feasibility_mask=feas, |
604 |
| - ) |
605 |
| - self.partitioning = self.p_class( |
606 |
| - ref_point=self.ref_point, Y=obj, **self.p_kwargs |
607 |
| - ) |
608 |
| - cell_bounds = self.partitioning.get_hypercell_bounds().to(self.ref_point) |
609 |
| - cell_bounds = cell_bounds.view( |
610 |
| - 2, *self._batch_sample_shape, *cell_bounds.shape[-2:] |
611 |
| - ) |
612 |
| - self.register_buffer("cell_lower_bounds", cell_bounds[0]) |
613 |
| - self.register_buffer("cell_upper_bounds", cell_bounds[1]) |
614 |
| - |
615 |
| - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: |
616 |
| - r"""Informs the acquisition function about pending design points. |
617 |
| -
|
618 |
| - Args: |
619 |
| - X_pending: `n x d` Tensor with `n` `d`-dim design points that have |
620 |
| - been submitted for evaluation but have not yet been evaluated. |
621 |
| - """ |
622 |
| - if X_pending is None: |
623 |
| - self.X_pending = None |
624 |
| - else: |
625 |
| - if X_pending.requires_grad: |
626 |
| - warnings.warn( |
627 |
| - "Pending points require a gradient but the acquisition function" |
628 |
| - " will not provide a gradient to these points.", |
629 |
| - BotorchWarning, |
630 |
| - ) |
631 |
| - X_pending = X_pending.detach().clone() |
632 |
| - if self.cache_pending: |
633 |
| - X_baseline = torch.cat([self._X_baseline, X_pending], dim=-2) |
634 |
| - # Number of new points is the total number of points minus |
635 |
| - # (the number of previously cached pending points plus the |
636 |
| - # of number of baseline points). |
637 |
| - num_new_points = X_baseline.shape[0] - self.X_baseline.shape[0] |
638 |
| - if num_new_points > 0: |
639 |
| - if num_new_points > self._max_iep: |
640 |
| - # Set the new baseline points to include pending points. |
641 |
| - self.register_buffer("_X_baseline_and_pending", X_baseline) |
642 |
| - # Recompute box decompositions. |
643 |
| - self._set_cell_bounds(num_new_points=num_new_points) |
644 |
| - if not self.incremental_nehvi: |
645 |
| - self._prev_nehvi = ( |
646 |
| - (self._hypervolumes - self._initial_hvs) |
647 |
| - .clamp_min(0.0) |
648 |
| - .mean() |
649 |
| - ) |
650 |
| - # Set to None so that pending points are not concatenated in |
651 |
| - # forward. |
652 |
| - self.X_pending = None |
653 |
| - # Set q_in=-1 to so that self.sampler is updated at the next |
654 |
| - # forward call. |
655 |
| - self.q_in = -1 |
656 |
| - else: |
657 |
| - self.X_pending = X_pending[-num_new_points:] |
658 |
| - else: |
659 |
| - self.X_pending = X_pending |
660 |
| - |
661 |
| - @property |
662 |
| - def _hypervolumes(self) -> Tensor: |
663 |
| - r"""Compute hypervolume over X_baseline under each posterior sample. |
664 |
| -
|
665 |
| - Returns: |
666 |
| - A `n_samples`-dim tensor of hypervolumes. |
667 |
| - """ |
668 |
| - return ( |
669 |
| - self.partitioning.compute_hypervolume() |
670 |
| - .to(self.ref_point) # for m > 2, the partitioning is on the CPU |
671 |
| - .view(self._batch_sample_shape) |
| 407 | + SubsetIndexCachingMixin.__init__(self) |
| 408 | + NoisyExpectedHypervolumeMixin.__init__( |
| 409 | + self, |
| 410 | + model=model, |
| 411 | + ref_point=ref_point, |
| 412 | + X_baseline=X_baseline, |
| 413 | + sampler=self.sampler, |
| 414 | + objective=self.objective, |
| 415 | + constraints=self.constraints, |
| 416 | + X_pending=X_pending, |
| 417 | + prune_baseline=prune_baseline, |
| 418 | + alpha=alpha, |
| 419 | + cache_pending=cache_pending, |
| 420 | + max_iep=max_iep, |
| 421 | + incremental_nehvi=incremental_nehvi, |
| 422 | + cache_root=cache_root, |
| 423 | + marginalize_dim=marginalize_dim, |
672 | 424 | )
|
673 | 425 |
|
674 | 426 | @concatenate_pending_points
|
|
0 commit comments