LearnedObjective and PairwiseGP dtype fixes and cleanup #2006

Closed · wants to merge 7 commits
Changes from 3 commits
20 changes: 19 additions & 1 deletion botorch/acquisition/objective.py
@@ -15,6 +15,7 @@

import torch
from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.model import Model
from botorch.models.transforms.outcome import Standardize
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
@@ -522,6 +523,13 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
)


def _get_learned_objective_pref_model_mixed_dtype_warn() -> str:
return (
"pref_model has double-precision data, but single-precision data "
"was passed to the LearnedObjective. Upcasting to double."
)
Contributor:
Why not make this a constant?
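For context, a minimal sketch of the reviewer's suggestion, i.e. a module-level constant instead of a zero-argument helper (the constant name below is hypothetical, not part of this PR):

# Hypothetical: a module-level constant in botorch/acquisition/objective.py
# that would replace the helper function above.
_MIXED_DTYPE_WARN_MSG = (
    "pref_model has double-precision data, but single-precision data "
    "was passed to the LearnedObjective. Upcasting to double."
)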



class LearnedObjective(MCAcquisitionObjective):
r"""Learned preference objective constructed from a preference model.

@@ -576,7 +584,17 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
A `(sample_size * num_samples) x batch_shape x N`-dim Tensor of
objective values sampled from the utility posterior using `pref_model`.
"""
post = self.pref_model.posterior(samples)
if samples.dtype == torch.float32 and any(
[d == torch.float64 for d in self.pref_model.dtypes_of_buffers]
):
warnings.warn(
_get_learned_objective_pref_model_mixed_dtype_warn(),
InputDataWarning,
)
samples = samples.to(torch.float64)

posterior_ = self.pref_model.posterior
post = posterior_(samples)
if isinstance(self.pref_model, DeterministicModel):
# return preference posterior mean
return post.mean.squeeze(-1)
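To see the new behavior end to end: a minimal sketch, assuming a double-precision PairwiseGP preference model and single-precision samples (the data here is illustrative):

import torch
from botorch.acquisition import LearnedObjective
from botorch.models.pairwise_gp import PairwiseGP

# A preference model trained in double precision.
train_X = torch.rand(4, 2, dtype=torch.float64)
train_comps = torch.LongTensor([[0, 1], [2, 3]])
pref_model = PairwiseGP(train_X, train_comps)
pref_obj = LearnedObjective(pref_model=pref_model)

# Single-precision samples now emit an InputDataWarning and are upcast,
# so the objective computes in float64 rather than mixing precisions.
samples = torch.rand(3, 2, 8, 2, dtype=torch.float32)
out = pref_obj(samples)
assert out.dtype == torch.float64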
10 changes: 10 additions & 0 deletions botorch/exceptions/warnings.py
@@ -55,3 +55,13 @@ class UserInputWarning(BotorchWarning):
r"""Warning raised when a potential issue is detected with user provided inputs."""

pass


def _get_single_precision_warning(dtype_str: str) -> str:
msg = (
f"The model inputs are of type {dtype_str}. It is strongly recommended "
"to use double precision in BoTorch, as this improves both "
"precision and stability and can help avoid numerical errors. "
"See https://github.com/pytorch/botorch/discussions/1444"
)
return msg
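A quick usage sketch of the relocated helper, which now takes a pre-formatted dtype string:

import torch
from botorch.exceptions.warnings import _get_single_precision_warning

# Callers pass str(dtype), e.g. "torch.float32", to build the message.
msg = _get_single_precision_warning(str(torch.float32))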
17 changes: 5 additions & 12 deletions botorch/models/gpytorch.py
@@ -22,7 +22,10 @@
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import BotorchTensorDimensionError, InputDataError
from botorch.exceptions.warnings import BotorchTensorDimensionWarning
from botorch.exceptions.warnings import (
_get_single_precision_warning,
BotorchTensorDimensionWarning,
)
from botorch.models.model import Model, ModelList
from botorch.models.utils import (
_make_X_full,
@@ -44,16 +47,6 @@
from gpytorch.likelihoods import Likelihood # pragma: no cover


def _get_single_precision_warning(dtype: torch.dtype) -> str:
msg = (
f"The model inputs are of type {dtype}. It is strongly recommended "
"to use double precision in BoTorch, as this improves both "
"precision and stability and can help avoid numerical errors. "
"See https://github.com/pytorch/botorch/discussions/1444"
)
return msg


class GPyTorchModel(Model, ABC):
r"""Abstract base class for models based on GPyTorch models.

@@ -126,7 +119,7 @@ def _validate_tensor_args(
)
if X.dtype != torch.float64:
# NOTE: Not using a BotorchWarning since those get ignored.
warnings.warn(_get_single_precision_warning(X.dtype), UserWarning)
warnings.warn(_get_single_precision_warning(str(X.dtype)), UserWarning)

@property
def batch_shape(self) -> torch.Size:
5 changes: 5 additions & 0 deletions botorch/models/model.py
@@ -24,6 +24,7 @@
List,
Mapping,
Optional,
Set,
TYPE_CHECKING,
TypeVar,
Union,
@@ -244,6 +245,10 @@ def train(self, mode: bool = True) -> Model:
self._set_transformed_inputs()
return super().train(mode=mode)

@property
def dtypes_of_buffers(self) -> Set[torch.dtype]:
return {t.dtype for t in self._buffers.values() if t is not None}


class FantasizeMixin(ABC):
"""
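A rough illustration of the new property, using GenericDeterministicModel as a stand-in (the buffer name is hypothetical):

import torch
from botorch.models.deterministic import GenericDeterministicModel

# Any Model subclass reports the dtypes of its directly registered buffers.
model = GenericDeterministicModel(f=lambda X: X.sum(dim=-1, keepdim=True))
model.register_buffer("example_stat", torch.zeros(2, dtype=torch.float64))
print(model.dtypes_of_buffers)  # {torch.float64}

LearnedObjective.forward uses this property to detect a double-precision pref_model before upcasting single-precision samples.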
50 changes: 35 additions & 15 deletions botorch/models/pairwise_gp.py
@@ -28,6 +28,8 @@
import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError

from botorch.exceptions.warnings import _get_single_precision_warning, InputDataWarning
from botorch.models.likelihoods.pairwise import (
PairwiseLikelihood,
PairwiseProbitLikelihood,
@@ -167,12 +169,17 @@ class PairwiseGP(Model, GP, FantasizeMixin):

def __init__(
self,
datapoints: Tensor,
comparisons: Tensor,
datapoints: Optional[Tensor],
comparisons: Optional[Tensor],
Contributor (commenting on lines +172 to +173):

Should we update the docstring with what happens when these are None?

likelihood: Optional[PairwiseLikelihood] = None,
covar_module: Optional[ScaleKernel] = None,
input_transform: Optional[InputTransform] = None,
**kwargs,
*,
jitter: float = 1e-6,
xtol: Optional[float] = None,
consolidate_rtol: float = 0.0,
consolidate_atol: float = 1e-4,
maxfev: Optional[int] = None,
) -> None:
r"""
Args:
@@ -184,22 +191,35 @@ def __init__(
covar_module: Covariance module.
input_transform: An input transform that is applied in the model's
forward pass.
jitter: Value added to diagonal for numerical stability in
`psd_safe_cholesky`.
xtol: Stopping criteria in scipy.optimize.fsolve used to find f_map
in `PairwiseGP._update`. If None, default behavior is handled by
`PairwiseGP._update`.
consolidate_rtol: `rtol` passed to `consolidate_duplicates`.
consolidate_atol: `atol` passed to `consolidate_duplicates`.
maxfev: The maximum number of calls to the function in
scipy.optimize.fsolve. If None, default behavior is handled by
`PairwiseGP._update`.
"""
super().__init__()
# Input data validation
if datapoints is not None and datapoints.dtype == torch.float32:
warnings.warn(
_get_single_precision_warning(str(datapoints.dtype)),
category=InputDataWarning,
)
if comparisons is not None and comparisons.dtype.is_floating_point:
warnings.warn(
"An integer dtype is expected for `comparisons`.", InputDataWarning
)

# Set optional parameters
# Explicitly set jitter for numerical stability in psd_safe_cholesky
self._jitter = kwargs.get("jitter", 1e-6)
# Stopping criteria in scipy.optimize.fsolve used to find f_map in _update()
# If None, set to 1e-6 by default in _update
self._xtol = kwargs.get("xtol")
# atol rtol for consolidate_duplicates
self._consolidate_rtol = kwargs.get("consolidate_rtol", 0)
self._consolidate_atol = kwargs.get("consolidate_atol", 1e-4)
# The maximum number of calls to the function in scipy.optimize.fsolve
# If None, set to 100 by default in _update
# If zero, then 100*(N+1) is used by default by fsolve;
self._maxfev = kwargs.get("maxfev")
self._jitter = jitter
self._xtol = xtol
self._consolidate_rtol = consolidate_rtol
self._consolidate_atol = consolidate_atol
self._maxfev = maxfev

if input_transform is not None:
input_transform.to(datapoints)
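A brief construction sketch with the now-explicit keyword-only options; the values shown are the defaults from the diff (data is illustrative):

import torch
from botorch.models.pairwise_gp import PairwiseGP

datapoints = torch.rand(4, 2, dtype=torch.float64)
comparisons = torch.LongTensor([[0, 1], [2, 3]])

# Options previously collected via **kwargs are now explicit
# keyword-only parameters with the same defaults.
model = PairwiseGP(
    datapoints,
    comparisons,
    jitter=1e-6,
    xtol=None,
    consolidate_rtol=0.0,
    consolidate_atol=1e-4,
    maxfev=None,
)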
2 changes: 1 addition & 1 deletion botorch/models/utils/assorted.py
@@ -317,7 +317,7 @@ def detect_duplicates(


def consolidate_duplicates(
X: Tensor, Y: Tensor, rtol: float = 0, atol: float = 1e-8
X: Tensor, Y: Tensor, rtol: float = 0.0, atol: float = 1e-8
) -> Tuple[Tensor, Tensor, Tensor]:
"""Drop duplicated Xs and update the indices tensor Y accordingly.
Supporting 2d Tensor only as in batch mode block design is not guaranteed.
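For reference, a rough usage sketch of the function whose default changed; the unpacked names are descriptive assumptions based on the docstring:

import torch
from botorch.models.utils.assorted import consolidate_duplicates

X = torch.tensor([[0.0, 0.0], [1.0, 1.0], [0.0, 0.0]], dtype=torch.float64)
Y = torch.LongTensor([[0, 1], [2, 1]])  # comparison pairs indexing rows of X

# Rows of X equal within rtol/atol are consolidated and Y is remapped.
consolidated_X, consolidated_Y, new_indices = consolidate_duplicates(
    X, Y, rtol=0.0, atol=1e-8
)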
128 changes: 104 additions & 24 deletions test/acquisition/test_objective.py
@@ -6,11 +6,13 @@

import itertools
import warnings
from typing import Optional

import torch
from botorch import settings
from botorch.acquisition import LearnedObjective
from botorch.acquisition.objective import (
_get_learned_objective_pref_model_mixed_dtype_warn,
ConstrainedMCObjective,
ExpectationPosteriorTransform,
GenericMCObjective,
@@ -21,8 +23,10 @@
ScalarizedPosteriorTransform,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import _get_single_precision_warning, InputDataWarning
from botorch.models.deterministic import PosteriorMeanModel
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import Normalize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils import apply_constraints
@@ -435,45 +439,121 @@ def test_linear_mc_objective(self) -> None:


class TestLearnedObjective(BotorchTestCase):
def test_learned_preference_objective(self):
X_dim = 2
train_X = torch.rand(2, X_dim)
def setUp(self, suppress_input_warnings: bool = False) -> None:
super().setUp(suppress_input_warnings=suppress_input_warnings)
self.x_dim = 2

def _get_pref_model(
self,
dtype: Optional[torch.dtype] = None,
input_transform: Optional[Normalize] = None,
) -> PairwiseGP:
train_X = torch.rand((2, self.x_dim), dtype=dtype)
train_comps = torch.LongTensor([[0, 1]])
pref_model = PairwiseGP(train_X, train_comps)
pref_model = PairwiseGP(train_X, train_comps, input_transform=input_transform)
return pref_model

def test_learned_preference_objective(self) -> None:
pref_model = self._get_pref_model(dtype=torch.float64)

og_sample_shape = 3
batch_size = 2
n = 8
test_X = torch.rand(torch.Size((og_sample_shape, batch_size, n, X_dim)))
test_X = torch.rand(
torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
dtype=torch.float64,
)

# test default setting where sampler =
# IIDNormalSampler(sample_shape=torch.Size([1]))
pref_obj = LearnedObjective(pref_model=pref_model)
self.assertEqual(
pref_obj(test_X).shape, torch.Size([og_sample_shape, batch_size, n])
)
with self.subTest("default sampler"):
pref_obj = LearnedObjective(pref_model=pref_model)
first_call_output = pref_obj(test_X)
self.assertEqual(
first_call_output.shape, torch.Size([og_sample_shape, batch_size, n])
)

# test when sampler has num_samples = 16
num_samples = 16
pref_obj = LearnedObjective(
pref_model=pref_model,
sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
)
self.assertEqual(
pref_obj(test_X).shape,
torch.Size([num_samples * og_sample_shape, batch_size, n]),
)
with self.subTest("SobolQMCNormalSampler"):
num_samples = 16
pref_obj = LearnedObjective(
pref_model=pref_model,
sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
)
self.assertEqual(
pref_obj(test_X).shape,
torch.Size([num_samples * og_sample_shape, batch_size, n]),
)

# test posterior mean
mean_pref_model = PosteriorMeanModel(model=pref_model)
pref_obj = LearnedObjective(pref_model=mean_pref_model)
self.assertEqual(
pref_obj(test_X).shape, torch.Size([og_sample_shape, batch_size, n])
)
with self.subTest("PosteriorMeanModel"):
mean_pref_model = PosteriorMeanModel(model=pref_model)
pref_obj = LearnedObjective(pref_model=mean_pref_model)
self.assertEqual(
pref_obj(test_X).shape, torch.Size([og_sample_shape, batch_size, n])
)

# cannot use a deterministic model together with a sampler
with self.assertRaises(AssertionError):
with self.subTest("deterministic model"), self.assertRaises(AssertionError):
LearnedObjective(
pref_model=mean_pref_model,
sampler=SobolQMCNormalSampler(sample_shape=torch.Size([num_samples])),
)

def test_dtype_compatibility_with_PairwiseGP(self) -> None:
og_sample_shape = 3
batch_size = 2
n = 8

test_X = torch.rand(
torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
)

for pref_model_dtype, test_x_dtype, expected_output_dtype in [
(torch.float64, torch.float64, torch.float64),
(torch.float32, torch.float32, torch.float32),
(torch.float64, torch.float32, torch.float64),
]:
with self.subTest(
"numerical behavior",
pref_model_dtype=pref_model_dtype,
test_x_dtype=test_x_dtype,
expected_output_dtype=expected_output_dtype,
):
# Ignore a single-precision warning in PairwiseGP
# and mixed-precision warning tested below
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=InputDataWarning,
message=_get_single_precision_warning(str(torch.float32)),
)
pref_model = self._get_pref_model(
dtype=pref_model_dtype,
input_transform=Normalize(d=2),
)
pref_obj = LearnedObjective(pref_model=pref_model)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=InputDataWarning,
message=_get_learned_objective_pref_model_mixed_dtype_warn(),
)
first_call_output = pref_obj(test_X.to(dtype=test_x_dtype))
second_call_output = pref_obj(test_X.to(dtype=test_x_dtype))

self.assertEqual(first_call_output.dtype, expected_output_dtype)
self.assertTrue(torch.equal(first_call_output, second_call_output))

with self.subTest("mixed precision warning"):
# should warn and test should pass
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=InputDataWarning)
pref_model = self._get_pref_model(
dtype=torch.float64, input_transform=Normalize(d=2)
)
pref_obj = LearnedObjective(pref_model=pref_model)
with self.assertWarnsRegex(
InputDataWarning, _get_learned_objective_pref_model_mixed_dtype_warn()
):
first_call_output = pref_obj(test_X)