10
10
from botorch .models .contextual_multioutput import FixedNoiseLCEMGP , LCEMGP
11
11
from botorch .models .multitask import MultiTaskGP
12
12
from botorch .posteriors import GPyTorchPosterior
13
+ from botorch .utils .test_helpers import gen_multi_task_dataset
13
14
from botorch .utils .testing import BotorchTestCase
14
15
from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
15
16
from gpytorch .mlls .exact_marginal_log_likelihood import ExactMarginalLogLikelihood
22
23
23
24
class ContextualMultiOutputTest (BotorchTestCase ):
24
25
def test_LCEMGP (self ):
25
- d = 1
26
26
for dtype , fixed_noise in ((torch .float , True ), (torch .double , False )):
27
- # test with batch evaluation
28
- train_x = torch .rand (10 , d , device = self .device , dtype = dtype )
29
- train_y = torch .cos (train_x )
30
- # 2 contexts here
31
- task_indices = torch .tensor (
32
- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ],
33
- device = self .device ,
34
- dtype = dtype ,
27
+ _ , (train_x , train_y , train_yvar ) = gen_multi_task_dataset (
28
+ yvar = 0.01 if fixed_noise else None , dtype = dtype , device = self .device
35
29
)
36
- train_x = torch .cat ([train_x , task_indices .unsqueeze (- 1 )], axis = 1 )
37
-
38
- if fixed_noise :
39
- train_yvar = torch .ones (10 , 1 , device = self .device , dtype = dtype ) * 0.01
40
- else :
41
- train_yvar = None
30
+ task_feature = 0
42
31
model = LCEMGP (
43
32
train_X = train_x ,
44
33
train_Y = train_y ,
45
- task_feature = d ,
34
+ task_feature = task_feature ,
46
35
train_Yvar = train_yvar ,
47
36
)
48
37
@@ -65,20 +54,18 @@ def test_LCEMGP(self):
65
54
self .assertIsInstance (embeddings , Tensor )
66
55
self .assertEqual (embeddings .shape , torch .Size ([2 , 1 ]))
67
56
68
- test_x = torch .rand (5 , d , device = self .device , dtype = dtype )
69
- task_indices = torch .tensor (
70
- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], device = self .device , dtype = dtype
71
- )
72
- test_x = torch .cat ([test_x , task_indices .unsqueeze (- 1 )], axis = 1 )
57
+ test_x = train_x [:5 ]
73
58
self .assertIsInstance (model (test_x ), MultivariateNormal )
74
59
75
60
# test posterior
76
- posterior_f = model .posterior (test_x [:, : d ])
61
+ posterior_f = model .posterior (test_x [:, task_feature + 1 : ])
77
62
self .assertIsInstance (posterior_f , GPyTorchPosterior )
78
63
self .assertIsInstance (posterior_f .distribution , MultitaskMultivariateNormal )
79
64
80
65
# test posterior w/ single output index
81
- posterior_f = model .posterior (test_x [:, :d ], output_indices = [0 ])
66
+ posterior_f = model .posterior (
67
+ test_x [:, task_feature + 1 :], output_indices = [0 ]
68
+ )
82
69
self .assertIsInstance (posterior_f , GPyTorchPosterior )
83
70
self .assertIsInstance (posterior_f .distribution , MultivariateNormal )
84
71
@@ -87,9 +74,9 @@ def test_LCEMGP(self):
87
74
model2 = LCEMGP (
88
75
train_X = train_x ,
89
76
train_Y = train_y ,
90
- task_feature = d ,
77
+ task_feature = task_feature ,
91
78
embs_dim_list = [2 ], # increase dim from 1 to 2
92
- context_emb_feature = torch .Tensor ([[0.2 ], [0.3 ]]),
79
+ context_emb_feature = torch .tensor ([[0.2 ], [0.3 ]]),
93
80
)
94
81
self .assertIsInstance (model2 , LCEMGP )
95
82
self .assertIsInstance (model2 , MultiTaskGP )
@@ -113,37 +100,63 @@ def test_LCEMGP(self):
113
100
left_interp_indices = task_idcs ,
114
101
right_interp_indices = task_idcs ,
115
102
).to_dense ()
116
- self .assertAllClose (previous_covar , model .task_covar_matrix (task_idcs ))
103
+ self .assertAllClose (previous_covar , model .task_covar_module (task_idcs ))
117
104
118
105
def test_FixedNoiseLCEMGP (self ):
119
- d = 1
120
106
for dtype in (torch .float , torch .double ):
121
- train_x = torch .rand (10 , d , device = self .device , dtype = dtype )
122
- train_y = torch .cos (train_x )
123
- task_indices = torch .tensor (
124
- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 ], device = self .device
107
+ _ , (train_x , train_y , train_yvar ) = gen_multi_task_dataset (
108
+ yvar = 0.01 , dtype = dtype , device = self .device
125
109
)
126
- train_x = torch .cat ([train_x , task_indices .unsqueeze (- 1 )], axis = 1 )
127
- train_yvar = torch .ones (10 , 1 , device = self .device , dtype = dtype ) * 0.01
128
110
129
111
with self .assertWarnsRegex (DeprecationWarning , "FixedNoiseLCEMGP" ):
130
112
model = FixedNoiseLCEMGP (
131
113
train_X = train_x ,
132
114
train_Y = train_y ,
133
115
train_Yvar = train_yvar ,
134
- task_feature = d ,
116
+ task_feature = 0 ,
135
117
)
136
118
mll = ExactMarginalLogLikelihood (model .likelihood , model )
137
119
fit_gpytorch_mll (mll , optimizer_kwargs = {"options" : {"maxiter" : 1 }})
138
-
139
120
self .assertIsInstance (model , FixedNoiseLCEMGP )
140
121
141
- test_x = torch .rand (5 , d , device = self .device , dtype = dtype )
142
- task_indices = torch .tensor (
143
- [0.0 , 0.0 , 0.0 , 0.0 , 0.0 ], device = self .device , dtype = dtype
122
+ test_x = train_x [:5 ]
123
+ self .assertIsInstance (model (test_x ), MultivariateNormal )
124
+
125
+ def test_construct_inputs (self ) -> None :
126
+ for with_embedding_inputs , yvar in ((True , None ), (False , 0.01 )):
127
+ dataset , (train_x , train_y , train_yvar ) = gen_multi_task_dataset (
128
+ yvar = yvar , dtype = torch .double , device = self .device
144
129
)
145
- test_x = torch .cat (
146
- [test_x , task_indices .unsqueeze (- 1 )],
147
- axis = 1 ,
130
+ model_inputs = LCEMGP .construct_inputs (
131
+ training_data = dataset ,
132
+ task_feature = 0 ,
133
+ embs_dim_list = [2 ] if with_embedding_inputs else None ,
134
+ context_emb_feature = (
135
+ torch .tensor ([[0.2 ], [0.3 ]]) if with_embedding_inputs else None
136
+ ),
137
+ context_cat_feature = (
138
+ torch .tensor ([[0.4 ], [0.5 ]]) if with_embedding_inputs else None
139
+ ),
148
140
)
149
- self .assertIsInstance (model (test_x ), MultivariateNormal )
141
+ # Check that the model inputs are valid.
142
+ LCEMGP (** model_inputs )
143
+ # Check that the model inputs are as expected.
144
+ self .assertAllClose (model_inputs .pop ("train_X" ), train_x )
145
+ self .assertAllClose (model_inputs .pop ("train_Y" ), train_y )
146
+ if yvar is not None :
147
+ self .assertAllClose (model_inputs .pop ("train_Yvar" ), train_yvar )
148
+ if with_embedding_inputs :
149
+ self .assertEqual (model_inputs .pop ("embs_dim_list" ), [2 ])
150
+ self .assertAllClose (
151
+ model_inputs .pop ("context_emb_feature" ),
152
+ torch .tensor ([[0.2 ], [0.3 ]]),
153
+ )
154
+ self .assertAllClose (
155
+ model_inputs .pop ("context_cat_feature" ),
156
+ torch .tensor ([[0.4 ], [0.5 ]]),
157
+ )
158
+ self .assertEqual (model_inputs .pop ("all_tasks" ), [0 , 1 ])
159
+ self .assertEqual (model_inputs .pop ("task_feature" ), 0 )
160
+ self .assertIsNone (model_inputs .pop ("output_tasks" ))
161
+ # Check that there are no unexpected inputs.
162
+ self .assertEqual (model_inputs , {})
0 commit comments