
Commit dc219ca

saitcakmak authored and facebook-github-bot committed
Increase code-sharing of LCEMGP & define construct_inputs (#2291)
Summary:
Pull Request resolved: #2291

Increases code sharing between LCEMGP & the parent MultiTaskGP:
- Allows customizing mean, covariance & likelihood modules.
- Eliminates duplicate `forward` implementation by renaming `task_covar_matrix` to `task_covar_module`.

Defines a `construct_inputs` method for LCEMGP that supports the kwargs used to customize the task covariance module (which differ from those used for MultiTaskGP).

Reviewed By: Balandat

Differential Revision: D55935507

fbshipit-source-id: ed6fc6e47eeb02d0dddd7657df809f9623142a81
1 parent c9966e9 commit dc219ca
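
The headline addition is the `construct_inputs` classmethod. As a minimal sketch of the resulting workflow (this mirrors the new unit test below; `gen_multi_task_dataset` is a test helper rather than a public API, and the embedding kwargs are illustrative):

    import torch
    from botorch.models.contextual_multioutput import LCEMGP
    from botorch.utils.test_helpers import gen_multi_task_dataset

    # Build a two-task dataset (any MultiTaskDataset would do here).
    dataset, _ = gen_multi_task_dataset(dtype=torch.double)
    # Convert the dataset plus LCEMGP-specific kwargs into constructor kwargs.
    model_inputs = LCEMGP.construct_inputs(
        training_data=dataset,
        task_feature=0,
        embs_dim_list=[2],  # embedding dimension per categorical context feature
        context_emb_feature=torch.tensor([[0.2], [0.3]]),  # one row per context
    )
    model = LCEMGP(**model_inputs)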

File tree

3 files changed: +129 -64 lines changed


botorch/models/contextual_multioutput.py (+73 -21)
@@ -14,15 +14,17 @@
 """

 import warnings
-from typing import List, Optional
+from typing import Any, Dict, List, Optional, Union

 import torch
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
 from gpytorch.constraints import Interval
-from gpytorch.distributions.multivariate_normal import MultivariateNormal
 from gpytorch.kernels.rbf_kernel import RBFKernel
+from gpytorch.likelihoods.likelihood import Likelihood
+from gpytorch.module import Module
 from linear_operator.operators import LinearOperator
 from torch import Tensor
 from torch.nn import ModuleList
@@ -41,6 +43,9 @@ def __init__(
         train_Y: Tensor,
         task_feature: int,
         train_Yvar: Optional[Tensor] = None,
+        mean_module: Optional[Module] = None,
+        covar_module: Optional[Module] = None,
+        likelihood: Optional[Likelihood] = None,
         context_cat_feature: Optional[Tensor] = None,
         context_emb_feature: Optional[Tensor] = None,
         embs_dim_list: Optional[List[int]] = None,
@@ -57,6 +62,12 @@ def __init__(
             train_Yvar: An optional (n x 1) tensor of observed variances of each
                 training Y. If None, we infer the noise. Note that the inferred noise
                 is common across all tasks.
+            mean_module: The mean function to be used. Defaults to `ConstantMean`.
+            covar_module: The module for computing the covariance matrix between
+                the non-task features. Defaults to `MaternKernel`.
+            likelihood: A likelihood. The default is selected based on `train_Yvar`.
+                If `train_Yvar` is None, a standard `GaussianLikelihood` with inferred
+                noise level is used. Otherwise, a FixedNoiseGaussianLikelihood is used.
             context_cat_feature: (n_contexts x k) one-hot encoded context
                 features. Rows are ordered by context indices, where k is the
                 number of categorical variables. If None, task indices will
@@ -74,29 +85,40 @@ def __init__(
                 training data. Note that when a task is not observed, the corresponding
                 task covariance will heavily depend on random initialization and may
                 behave unexpectedly.
+            input_transform: An input transform that is applied in the model's
+                forward pass.
+            outcome_transform: An outcome transform that is applied to the
+                training data during instantiation and to the posterior during
+                inference (that is, the `Posterior` obtained by calling
+                `.posterior` on the model will be on the original scale).
         """
         super().__init__(
             train_X=train_X,
             train_Y=train_Y,
             task_feature=task_feature,
             train_Yvar=train_Yvar,
+            mean_module=mean_module,
+            covar_module=covar_module,
+            likelihood=likelihood,
             output_tasks=output_tasks,
             all_tasks=all_tasks,
             input_transform=input_transform,
             outcome_transform=outcome_transform,
         )
         self.device = train_X.device
         if all_tasks is None:
-            all_tasks = train_X[:, task_feature].unique()
-            self.all_tasks = all_tasks.to(dtype=torch.long).tolist()
+            all_tasks_tensor = train_X[:, task_feature].unique()
+            self.all_tasks = all_tasks_tensor.to(dtype=torch.long).tolist()
         else:
-            all_tasks = torch.tensor(all_tasks, dtype=torch.long)
+            all_tasks_tensor = torch.tensor(all_tasks, dtype=torch.long)
             self.all_tasks = all_tasks
         self.all_tasks.sort()  # These are the context indices.

         if context_cat_feature is None:
-            context_cat_feature = all_tasks.unsqueeze(-1).to(device=self.device)
-        self.context_cat_feature = context_cat_feature  # row indices = context indices
+            context_cat_feature = all_tasks_tensor.unsqueeze(-1).to(device=self.device)
+        self.context_cat_feature: Tensor = (
+            context_cat_feature  # row indices = context indices
+        )
         self.context_emb_feature = context_emb_feature

         # construct emb_dims based on categorical features
@@ -115,7 +137,7 @@ def __init__(
                 for x, y in self.emb_dims
             ]
         )
-        self.task_covar_module = RBFKernel(
+        self.task_covar_module_base = RBFKernel(
             ard_num_dims=n_embs,
             lengthscale_constraint=Interval(
                 0.0, 2.0, transform=None, initial_value=1.0
@@ -132,7 +154,7 @@ def _eval_context_covar(self) -> LinearOperator:
         to get the task covariance matrix.
         """
         all_embs = self._task_embeddings()
-        return self.task_covar_module(all_embs)
+        return self.task_covar_module_base(all_embs)

     def _task_embeddings(self) -> Tensor:
         """Generate embedding features for all contexts."""
@@ -154,7 +176,7 @@ def _task_embeddings(self) -> Tensor:
         )
         return embeddings

-    def task_covar_matrix(self, task_idcs: Tensor) -> Tensor:
+    def task_covar_module(self, task_idcs: Tensor) -> Tensor:
         r"""Compute the task covariance matrix for a given tensor of
         task / context indices.

@@ -184,17 +206,47 @@ def task_covar_matrix(self, task_idcs: Tensor) -> Tensor:
             covar_matrix[base_idx].transpose(-1, -2).gather(index=expanded_idx, dim=-2)
         )

-    def forward(self, x: Tensor) -> MultivariateNormal:
-        if self.training:
-            x = self.transform_inputs(x)
-        x_basic, task_idcs = self._split_inputs(x)
-        # Compute base mean and covariance
-        mean_x = self.mean_module(x_basic)
-        covar_x = self.covar_module(x_basic)
-        # Compute task covariances
-        covar_i = self.task_covar_matrix(task_idcs)
-        covar = covar_x.mul(covar_i)
-        return MultivariateNormal(mean_x, covar)
+    @classmethod
+    def construct_inputs(
+        cls,
+        training_data: Union[SupervisedDataset, MultiTaskDataset],
+        task_feature: int,
+        output_tasks: Optional[List[int]] = None,
+        context_cat_feature: Optional[Tensor] = None,
+        context_emb_feature: Optional[Tensor] = None,
+        embs_dim_list: Optional[List[int]] = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        r"""Construct `Model` keyword arguments from a dataset and other args.
+
+        Args:
+            training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
+            task_feature: Column index of embedded task indicator features.
+            output_tasks: A list of task indices for which to compute model
+                outputs for. If omitted, return outputs for all task indices.
+            context_cat_feature: (n_contexts x k) one-hot encoded context
+                features. Rows are ordered by context indices, where k is the
+                number of categorical variables. If None, task indices will
+                be used and k = 1.
+            context_emb_feature: (n_contexts x m) pre-given continuous
+                embedding features. Rows are ordered by context indices.
+            embs_dim_list: Embedding dimension for each categorical variable.
+                The length equals k. If None, the embedding dimension is set to 1
+                for each categorical variable.
+        """
+        base_inputs = super().construct_inputs(
+            training_data=training_data,
+            task_feature=task_feature,
+            output_tasks=output_tasks,
+            **kwargs,
+        )
+        if context_cat_feature is not None:
+            base_inputs["context_cat_feature"] = context_cat_feature
+        if context_emb_feature is not None:
+            base_inputs["context_emb_feature"] = context_emb_feature
+        if embs_dim_list is not None:
+            base_inputs["embs_dim_list"] = embs_dim_list
+        return base_inputs


 class FixedNoiseLCEMGP(LCEMGP):
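
Because `mean_module`, `covar_module` and `likelihood` are now forwarded to `MultiTaskGP`, callers can swap in their own modules. A minimal sketch, assuming 10 training points with one non-task feature in column 0 and task indices in column 1 (the specific module choices are illustrative, not the defaults):

    import torch
    from botorch.models.contextual_multioutput import LCEMGP
    from gpytorch.kernels import MaternKernel, ScaleKernel
    from gpytorch.means import ZeroMean

    train_X = torch.rand(10, 1, dtype=torch.double)
    task_idx = (torch.arange(10) >= 5).to(train_X).unsqueeze(-1)  # tasks 0 and 1
    train_X = torch.cat([train_X, task_idx], dim=-1)  # task indices in column 1
    train_Y = torch.sin(train_X[:, :1])

    model = LCEMGP(
        train_X=train_X,
        train_Y=train_Y,
        task_feature=1,
        mean_module=ZeroMean(),  # instead of the default ConstantMean
        covar_module=ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=1)),
    )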

botorch/utils/test_helpers.py (+1 -1)

@@ -67,7 +67,7 @@ def standardize_moments(

 def gen_multi_task_dataset(
     yvar: Optional[float] = None, task_values: Optional[List[int]] = None, **tkwargs
-) -> Tuple[MultiTaskDataset, Tuple[Tensor, Tensor, Tensor]]:
+) -> Tuple[MultiTaskDataset, Tuple[Tensor, Tensor, Optional[Tensor]]]:
     """Constructs a multi-task dataset with two tasks, each with 10 data points."""
     X = torch.linspace(0, 0.95, 10, **tkwargs) + 0.05 * torch.rand(10, **tkwargs)
     X = X.unsqueeze(dim=-1)
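
The widened return annotation matches how the helper is used in the tests: when `yvar` is None, the third element of the tensor tuple is None and the model infers the noise. A small sketch of both cases (assuming the helper's current behavior):

    import torch
    from botorch.utils.test_helpers import gen_multi_task_dataset

    _, (train_x, train_y, train_yvar) = gen_multi_task_dataset(yvar=None, dtype=torch.double)
    assert train_yvar is None  # inferred-noise case

    _, (_, _, train_yvar) = gen_multi_task_dataset(yvar=0.01, dtype=torch.double)
    assert train_yvar is not None  # fixed observation noise of 0.01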

test/models/test_contextual_multioutput.py (+55 -42)

@@ -10,6 +10,7 @@
 from botorch.models.contextual_multioutput import FixedNoiseLCEMGP, LCEMGP
 from botorch.models.multitask import MultiTaskGP
 from botorch.posteriors import GPyTorchPosterior
+from botorch.utils.test_helpers import gen_multi_task_dataset
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
@@ -22,27 +23,15 @@

 class ContextualMultiOutputTest(BotorchTestCase):
     def test_LCEMGP(self):
-        d = 1
         for dtype, fixed_noise in ((torch.float, True), (torch.double, False)):
-            # test with batch evaluation
-            train_x = torch.rand(10, d, device=self.device, dtype=dtype)
-            train_y = torch.cos(train_x)
-            # 2 contexts here
-            task_indices = torch.tensor(
-                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0],
-                device=self.device,
-                dtype=dtype,
+            _, (train_x, train_y, train_yvar) = gen_multi_task_dataset(
+                yvar=0.01 if fixed_noise else None, dtype=dtype, device=self.device
             )
-            train_x = torch.cat([train_x, task_indices.unsqueeze(-1)], axis=1)
-
-            if fixed_noise:
-                train_yvar = torch.ones(10, 1, device=self.device, dtype=dtype) * 0.01
-            else:
-                train_yvar = None
+            task_feature = 0
             model = LCEMGP(
                 train_X=train_x,
                 train_Y=train_y,
-                task_feature=d,
+                task_feature=task_feature,
                 train_Yvar=train_yvar,
             )

@@ -65,20 +54,18 @@ def test_LCEMGP(self):
             self.assertIsInstance(embeddings, Tensor)
             self.assertEqual(embeddings.shape, torch.Size([2, 1]))

-            test_x = torch.rand(5, d, device=self.device, dtype=dtype)
-            task_indices = torch.tensor(
-                [0.0, 0.0, 0.0, 0.0, 0.0], device=self.device, dtype=dtype
-            )
-            test_x = torch.cat([test_x, task_indices.unsqueeze(-1)], axis=1)
+            test_x = train_x[:5]
             self.assertIsInstance(model(test_x), MultivariateNormal)

             # test posterior
-            posterior_f = model.posterior(test_x[:, :d])
+            posterior_f = model.posterior(test_x[:, task_feature + 1 :])
             self.assertIsInstance(posterior_f, GPyTorchPosterior)
             self.assertIsInstance(posterior_f.distribution, MultitaskMultivariateNormal)

             # test posterior w/ single output index
-            posterior_f = model.posterior(test_x[:, :d], output_indices=[0])
+            posterior_f = model.posterior(
+                test_x[:, task_feature + 1 :], output_indices=[0]
+            )
             self.assertIsInstance(posterior_f, GPyTorchPosterior)
             self.assertIsInstance(posterior_f.distribution, MultivariateNormal)

@@ -87,9 +74,9 @@ def test_LCEMGP(self):
             model2 = LCEMGP(
                 train_X=train_x,
                 train_Y=train_y,
-                task_feature=d,
+                task_feature=task_feature,
                 embs_dim_list=[2],  # increase dim from 1 to 2
-                context_emb_feature=torch.Tensor([[0.2], [0.3]]),
+                context_emb_feature=torch.tensor([[0.2], [0.3]]),
             )
             self.assertIsInstance(model2, LCEMGP)
             self.assertIsInstance(model2, MultiTaskGP)
@@ -113,37 +100,63 @@ def test_LCEMGP(self):
                 left_interp_indices=task_idcs,
                 right_interp_indices=task_idcs,
             ).to_dense()
-            self.assertAllClose(previous_covar, model.task_covar_matrix(task_idcs))
+            self.assertAllClose(previous_covar, model.task_covar_module(task_idcs))

     def test_FixedNoiseLCEMGP(self):
-        d = 1
         for dtype in (torch.float, torch.double):
-            train_x = torch.rand(10, d, device=self.device, dtype=dtype)
-            train_y = torch.cos(train_x)
-            task_indices = torch.tensor(
-                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0], device=self.device
+            _, (train_x, train_y, train_yvar) = gen_multi_task_dataset(
+                yvar=0.01, dtype=dtype, device=self.device
             )
-            train_x = torch.cat([train_x, task_indices.unsqueeze(-1)], axis=1)
-            train_yvar = torch.ones(10, 1, device=self.device, dtype=dtype) * 0.01

             with self.assertWarnsRegex(DeprecationWarning, "FixedNoiseLCEMGP"):
                 model = FixedNoiseLCEMGP(
                     train_X=train_x,
                     train_Y=train_y,
                     train_Yvar=train_yvar,
-                    task_feature=d,
+                    task_feature=0,
                 )
             mll = ExactMarginalLogLikelihood(model.likelihood, model)
             fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 1}})
-
             self.assertIsInstance(model, FixedNoiseLCEMGP)

-            test_x = torch.rand(5, d, device=self.device, dtype=dtype)
-            task_indices = torch.tensor(
-                [0.0, 0.0, 0.0, 0.0, 0.0], device=self.device, dtype=dtype
+            test_x = train_x[:5]
+            self.assertIsInstance(model(test_x), MultivariateNormal)
+
+    def test_construct_inputs(self) -> None:
+        for with_embedding_inputs, yvar in ((True, None), (False, 0.01)):
+            dataset, (train_x, train_y, train_yvar) = gen_multi_task_dataset(
+                yvar=yvar, dtype=torch.double, device=self.device
            )
-            test_x = torch.cat(
-                [test_x, task_indices.unsqueeze(-1)],
-                axis=1,
+            model_inputs = LCEMGP.construct_inputs(
+                training_data=dataset,
+                task_feature=0,
+                embs_dim_list=[2] if with_embedding_inputs else None,
+                context_emb_feature=(
+                    torch.tensor([[0.2], [0.3]]) if with_embedding_inputs else None
+                ),
+                context_cat_feature=(
+                    torch.tensor([[0.4], [0.5]]) if with_embedding_inputs else None
+                ),
             )
-            self.assertIsInstance(model(test_x), MultivariateNormal)
+            # Check that the model inputs are valid.
+            LCEMGP(**model_inputs)
+            # Check that the model inputs are as expected.
+            self.assertAllClose(model_inputs.pop("train_X"), train_x)
+            self.assertAllClose(model_inputs.pop("train_Y"), train_y)
+            if yvar is not None:
+                self.assertAllClose(model_inputs.pop("train_Yvar"), train_yvar)
+            if with_embedding_inputs:
+                self.assertEqual(model_inputs.pop("embs_dim_list"), [2])
+                self.assertAllClose(
+                    model_inputs.pop("context_emb_feature"),
+                    torch.tensor([[0.2], [0.3]]),
+                )
+                self.assertAllClose(
+                    model_inputs.pop("context_cat_feature"),
+                    torch.tensor([[0.4], [0.5]]),
+                )
+            self.assertEqual(model_inputs.pop("all_tasks"), [0, 1])
+            self.assertEqual(model_inputs.pop("task_feature"), 0)
+            self.assertIsNone(model_inputs.pop("output_tasks"))
+            # Check that there are no unexpected inputs.
+            self.assertEqual(model_inputs, {})
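
A migration note implied by the tests above: the old `task_covar_matrix(task_idcs)` call becomes `task_covar_module(task_idcs)`, and the `RBFKernel` over the context embeddings is now reachable as `task_covar_module_base`. Per the commit summary's rationale, giving the method the name `task_covar_module` lets LCEMGP inherit `MultiTaskGP.forward` (which evaluates `self.task_covar_module(task_idcs)`) unchanged. A hedged sketch, reusing the `model` from the earlier example:

    task_idcs = torch.tensor([[0], [1]])  # context indices, one per row
    task_covar = model.task_covar_module(task_idcs)  # was: model.task_covar_matrix(...)
    embedding_kernel = model.task_covar_module_base  # RBFKernel over context embeddings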
