Skip to content

Commit a78ef1b

Browse files
blethamstefanpricopie
authored andcommitted
Update contextual models for use in MBM (pytorch#2206)
Summary: X-link: facebook/Ax#2206 Pull Request resolved: pytorch#2206 Updates for contextual models to be used in MBM Reviewed By: saitcakmak Differential Revision: D52452235 fbshipit-source-id: 6fa479e1498789d9330cc7b07ac0500612bfd505
1 parent 485b195 commit a78ef1b

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

botorch/models/contextual.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
def construct_inputs(
122122
cls,
123123
training_data: SupervisedDataset,
124-
decomposition: Dict[str, List[int]],
124+
decomposition: Dict[str, List[str]],
125125
train_embedding: bool = True,
126126
cat_feature_dict: Optional[Dict] = None,
127127
embs_feature_dict: Optional[Dict] = None,
@@ -133,7 +133,7 @@ def construct_inputs(
133133
134134
Args:
135135
training_data: A `SupervisedDataset` containing the training data.
136-
decomposition: Dictionary of context names and their indexes of the
136+
decomposition: Dictionary of context names and the names of the
137137
corresponding active context parameters.
138138
train_embedding: Whether to train the embedding layer or not.
139139
cat_feature_dict: Keys are context names and values are list of categorical
@@ -148,9 +148,13 @@ def construct_inputs(
148148
context_weight_dict: Known population weights of each context.
149149
"""
150150
base_inputs = super().construct_inputs(training_data=training_data, **kwargs)
151+
index_decomp = {
152+
c: [training_data.feature_names.index(i) for i in v]
153+
for c, v in decomposition.items()
154+
}
151155
return {
152156
**base_inputs,
153-
"decomposition": decomposition,
157+
"decomposition": index_decomp,
154158
"train_embedding": train_embedding,
155159
"cat_feature_dict": cat_feature_dict,
156160
"embs_feature_dict": embs_feature_dict,

test/models/test_contextual.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ def test_LCEAGP_construct_inputs(self):
135135
for dtype in (torch.float, torch.double):
136136
tkwargs = {"device": self.device, "dtype": dtype}
137137
datasets, (train_X, train_Y, train_Yvar) = _gen_datasets(**tkwargs)
138-
decomposition = {"1": [0, 1], "2": [2, 3]}
138+
decomposition = {"1": ["x0", "x1"], "2": ["x2", "x3"]}
139+
decomposition_index = {"1": [0, 1], "2": [2, 3]}
139140

140-
model = LCEAGP(train_X, train_Y, train_Yvar, decomposition)
141-
data_dict = model.construct_inputs(
141+
data_dict = LCEAGP.construct_inputs(
142142
training_data=datasets,
143143
decomposition=decomposition,
144144
train_embedding=False,
@@ -147,5 +147,5 @@ def test_LCEAGP_construct_inputs(self):
147147
self.assertTrue(train_X.equal(data_dict["train_X"]))
148148
self.assertTrue(train_Y.equal(data_dict["train_Y"]))
149149
self.assertTrue(train_Yvar.equal(data_dict["train_Yvar"]))
150-
self.assertDictEqual(data_dict["decomposition"], decomposition)
150+
self.assertDictEqual(data_dict["decomposition"], decomposition_index)
151151
self.assertFalse(data_dict["train_embedding"])

0 commit comments

Comments
 (0)