Skip to content

Commit 3f8f20b

Browse files
Merge pull request NSLS-II#108 from thomaswmorris/fix-plots
Fix CI failures due to domain transforms
2 parents 0147758 + c0d2408 commit 3f8f20b

File tree

9 files changed

+368
-318
lines changed

9 files changed

+368
-318
lines changed

Diff for: src/blop/agent.py

+82 -64
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
from bluesky.run_engine import Msg
2020
from botorch.acquisition.acquisition import AcquisitionFunction # type: ignore[import-untyped]
2121
from botorch.acquisition.objective import ScalarizedPosteriorTransform # type: ignore[import-untyped]
22-
from botorch.models.deterministic import GenericDeterministicModel # type: ignore[import-untyped]
2322
from botorch.models.model import Model # type: ignore[import-untyped]
2423
from botorch.models.model_list_gp_regression import ModelListGP # type: ignore[import-untyped]
2524
from botorch.models.transforms.input import Normalize # type: ignore[import-untyped]
@@ -154,7 +153,11 @@ def raw_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.T
154153
"""
155154
if index is None:
156155
return torch.stack([self.raw_inputs(dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1)
157-
return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double)
156+
157+
key = self.dofs[index].name
158+
if key in self._table.columns:
159+
return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double)
160+
return torch.ones(0)
158161

159162
def train_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
160163
"""
@@ -175,7 +178,9 @@ def raw_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> d
175178
if index is None:
176179
return {obj.name: self.raw_targets_dict(obj.name)[obj.name] for obj in self.objectives(**subset_kwargs)}
177180
key = self.objectives[index].name
178-
return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)}
181+
if key in self._table.columns:
182+
return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)}
183+
return {key: torch.tensor([], dtype=torch.double)}
179184

180185
def raw_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
181186
"""
@@ -281,7 +286,7 @@ def fitness_scalarization(self, weights: str | torch.Tensor = "default") -> Scal
281286
weights *= len(active_fitness_objectives) / weights.sum()
282287
elif not isinstance(weights, torch.Tensor):
283288
raise ValueError(f"'weights' must be a Tensor or one of ['default', 'equal', 'random'], and not {weights}.")
284-
return ScalarizedPosteriorTransform(weights=weights)
289+
return ScalarizedPosteriorTransform(weights=weights * active_fitness_objectives.signs)
285290

286291
@property
287292
def fitness_model(self) -> Model:
@@ -318,24 +323,29 @@ def sample(self, n: int = DEFAULT_MAX_SAMPLES, normalize: bool = False, method:
318323

319324
active_dofs = self.dofs(active=True)
320325

321-
if method == "quasi-random":
322-
X = utils.normalized_sobol_sampler(n, d=len(active_dofs))
323-
324-
elif method == "random":
325-
X = torch.rand(size=(n, 1, len(active_dofs)))
326-
327-
elif method == "grid":
326+
if method == "grid":
328327
read_only_tensor = cast(torch.Tensor, active_dofs.read_only)
329328
n_side_if_settable = int(np.power(n, 1 / torch.sum(~read_only_tensor)))
330-
sides = [
331-
torch.linspace(0, 1, n_side_if_settable) if not dof.read_only else torch.zeros(1) for dof in active_dofs
332-
]
333-
X = torch.cat([x.unsqueeze(-1) for x in torch.meshgrid(sides, indexing="ij")], dim=-1).unsqueeze(-2).double()
329+
grid_sides = []
330+
for dof in active_dofs:
331+
if dof.read_only:
332+
grid_sides.append(dof._transform(torch.tensor([dof.readback], dtype=torch.double)))
333+
else:
334+
grid_side_bins = torch.linspace(0, 1, n_side_if_settable + 1, dtype=torch.double)
335+
grid_sides.append((grid_side_bins[:-1] + grid_side_bins[1:]) / 2)
336+
337+
tX = torch.stack(torch.meshgrid(grid_sides, indexing="ij"), dim=-1).unsqueeze(-2)
338+
339+
elif method == "quasi-random":
340+
tX = utils.normalized_sobol_sampler(n, d=len(active_dofs))
341+
342+
elif method == "random":
343+
tX = torch.rand(size=(n, 1, len(active_dofs)))
334344

335345
else:
336346
raise ValueError("'method' argument must be one of ['quasi-random', 'random', 'grid'].")
337347

338-
return X.double() if normalize else self.dofs(active=True).untransform(X).double()
348+
return tX.double() if normalize else self.dofs.untransform(tX.double())
339349

340350
# @property
341351
def pruned_mask(self) -> torch.Tensor:
@@ -387,7 +397,6 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)
387397

388398
if trusted.all():
389399
obj.validity_conjugate_model = None
390-
obj.validity_constraint = GenericDeterministicModel(f=lambda x: torch.ones(size=x.size())[..., -1])
391400

392401
else:
393402
dirichlet_likelihood = gpytorch.likelihoods.DirichletClassificationLikelihood(
@@ -402,37 +411,44 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)
402411
input_transform=self.input_normalization,
403412
)
404413

405-
obj.validity_constraint = GenericDeterministicModel(
406-
f=lambda x: obj.validity_conjugate_model.probabilities(x)[..., -1]
407-
)
408-
409414
def update_models(
410415
self,
411-
train: bool | None = None,
416+
force_train: bool = False,
412417
) -> None:
418+
"""
419+
We don't want to retrain the models on every call of everything, but if they are out of sync with
420+
the DOFs then we should.
421+
"""
422+
423+
active_dofs = self.dofs(active=True)
413424
objectives_to_model = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
425+
414426
for obj in objectives_to_model:
415-
t0 = ttime.monotonic()
427+
# do we need to update the model for this objective?
428+
n_trainable_points = sum(~self.train_targets(obj.name).isnan())
429+
430+
# if we don't have enough points
431+
if n_trainable_points < obj.min_points_to_train:
432+
continue
433+
434+
# if the current model matches the active dofs
435+
if getattr(obj, "model_dofs", {}) == set(active_dofs.names):
436+
# then we can use the current hyperparameters and just update the data
437+
cached_hypers = obj.model.state_dict() if obj.model else None
438+
439+
logger.debug(f'{getattr(obj, "model_dofs", {}) = }')
440+
logger.debug(f"{set(active_dofs.names) = }")
441+
# if there aren't enough extra points to train yet
442+
if n_trainable_points // self.train_every == len(obj.model.train_targets) // self.train_every:
443+
if not force_train:
444+
self._construct_model(obj)
445+
train_model(obj.model, hypers=cached_hypers)
446+
continue
416447

417-
cached_hypers = obj.model.state_dict() if obj.model else None
418-
n_before_tell = obj.n_valid
448+
t0 = ttime.monotonic()
419449
self._construct_model(obj)
420-
if not obj.model:
421-
raise RuntimeError(f"Expected {obj} to have a constructed model.")
422-
n_after_tell = obj.n_valid
423-
424-
if train is None:
425-
train = int(n_after_tell / self.train_every) > int(n_before_tell / self.train_every)
426-
427-
if len(obj.model.train_targets) >= 4:
428-
if train:
429-
t0 = ttime.monotonic()
430-
train_model(obj.model)
431-
if self.verbose:
432-
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
433-
434-
else:
435-
train_model(obj.model, hypers=cached_hypers)
450+
train_model(obj.model)
451+
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")
436452

437453
def tell(
438454
self,
@@ -441,8 +457,7 @@ def tell(
441457
y: Mapping | None = {},
442458
metadata: Mapping | None = {},
443459
append: bool = True,
444-
update_models: bool = True,
445-
train: bool | None = None,
460+
force_train: bool = False,
446461
) -> None:
447462
"""
448463
Inform the agent about new inputs and targets for the model.
@@ -477,12 +492,12 @@ def tell(
477492
if len(unique_field_lengths) > 1:
478493
raise ValueError("All supplies values must be the same length!")
479494

480-
# TODO: This is an innefficient approach to caching data. Keep a list, make table at update model time.
495+
# TODO: This is an inefficient approach to caching data. Keep a list, make table at update model time.
481496
new_table = pd.DataFrame(data)
482497
self._table = pd.concat([self._table, new_table]) if append else new_table
483498
self._table.index = pd.Index(np.arange(len(self._table)))
484-
if update_models:
485-
self.update_models(train=train)
499+
500+
self.update_models(force_train=force_train)
486501

487502
def ask(
488503
self, acqf: str = "qei", n: int = 1, route: bool = True, sequential: bool = True, upsample: int = 1, **acqf_kwargs
@@ -525,15 +540,11 @@ def ask(
525540
f"Can't construct non-trivial acquisition function '{acqf}' as the agent is not initialized."
526541
)
527542

528-
# if the model for any active objective mismatches the active dofs, reconstrut and train it
529-
for obj in active_objs:
530-
if hasattr(obj, "model_dofs") and obj.model_dofs != set(active_dofs.names):
531-
self._construct_model(obj)
532-
train_model(obj.model)
533-
534543
if acqf_config["type"] == "analytic" and n > 1:
535544
raise ValueError("Can't generate multiple design points for analytic acquisition functions.")
536545

546+
self.update_models()
547+
537548
# we may pick up some more kwargs
538549
acqf, acqf_kwargs = _construct_acqf(self, acqf_name=acqf_config["name"], **acqf_kwargs)
539550

@@ -556,8 +567,6 @@ def ask(
556567
# and is in the transformed model space
557568
candidates = self.dofs(active=True).untransform(candidates)
558569

559-
# p = self.posterior(candidates) if hasattr(self, "model") else None
560-
561570
active_dofs = self.dofs(active=True)
562571

563572
read_only_tensor = cast(torch.Tensor, active_dofs.read_only)
@@ -714,7 +723,7 @@ def learn(
714723
n: int = 1,
715724
iterations: int = 1,
716725
upsample: int = 1,
717-
train: bool | None = None,
726+
force_train: bool = False,
718727
append: bool = True,
719728
hypers: str | None = None,
720729
route: bool = True,
@@ -768,7 +777,7 @@ def learn(
768777
metadata = {
769778
key: new_table.loc[:, key].tolist() for key in new_table.columns if (key not in x) and (key not in y)
770779
}
771-
self.tell(x=x, y=y, metadata=metadata, append=append, train=train)
780+
self.tell(x=x, y=y, metadata=metadata, append=append, force_train=force_train)
772781

773782
def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = 2**16):
774783
"""
@@ -1040,8 +1049,8 @@ def _set_hypers(self, hypers: dict[str, Any]):
10401049
if not obj.model:
10411050
raise RuntimeError(f"Expected {obj} to have a constructed model.")
10421051
obj.model.load_state_dict(hypers[obj.name])
1043-
if self.validity_constraint:
1044-
self.validity_constraint.load_state_dict(hypers["validity_constraint"])
1052+
if self.validity_probability:
1053+
self.validity_probability.load_state_dict(hypers["validity_probability"])
10451054

10461055
@property
10471056
def hypers(self) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
@@ -1139,13 +1148,18 @@ def plot_objectives(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
11391148
axes :
11401149
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
11411150
"""
1142-
if len(self.dofs(active=True, read_only=False)) == 1:
1143-
if len(self.objectives(active=True, fitness=True)) > 0:
1144-
plotting._plot_fitness_objs_one_dof(self, **kwargs)
1145-
if len(self.objectives(active=True, constraint=True)) > 0:
1146-
plotting._plot_constraint_objs_one_dof(self, **kwargs)
1151+
self.update_models()
1152+
1153+
plottable_dofs = self.dofs(active=True, read_only=False)
1154+
logger.debug(f"Plotting agent with DOFs {self.dofs} and objectives {self.objectives}")
1155+
if len(plottable_dofs) == 0:
1156+
raise ValueError("To plot agent objectives, at least one writeable DOF must be active.")
1157+
elif len(plottable_dofs) == 1:
1158+
plotting._plot_objs_one_dof(self, **kwargs)
1159+
elif len(plottable_dofs) == 2:
1160+
plotting._plot_objs_many_dofs(self, gridded=True, axes=axes, **kwargs)
11471161
else:
1148-
plotting._plot_objs_many_dofs(self, axes=axes, **kwargs)
1162+
plotting._plot_objs_many_dofs(self, gridded=False, axes=axes, **kwargs)
11491163

11501164
def plot_acquisition(self, acqf: str = "ei", axes: tuple[int, int] = (0, 1), **kwargs) -> None:
11511165
"""Plot an acquisition function over test inputs sampling the limits of the parameter space.
@@ -1157,6 +1171,8 @@ def plot_acquisition(self, acqf: str = "ei", axes: tuple[int, int] = (0, 1), **k
11571171
axes :
11581172
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
11591173
"""
1174+
self.update_models()
1175+
11601176
if len(self.dofs(active=True, read_only=False)) == 1:
11611177
plotting._plot_acqf_one_dof(self, acqfs=np.atleast_1d(acqf), **kwargs)
11621178
else:
@@ -1170,6 +1186,8 @@ def plot_validity(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
11701186
axes :
11711187
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
11721188
"""
1189+
self.update_models()
1190+
11731191
if len(self.dofs(active=True, read_only=False)) == 1:
11741192
plotting._plot_valid_one_dof(self, **kwargs)
11751193
else:

Diff for: src/blop/bayesian/models.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def __init__(
129129

130130
self.trained: bool = False
131131

132-
def probabilities(self, x: torch.Tensor, n_samples: int = 1024) -> torch.Tensor:
132+
def probabilities(self, x: torch.Tensor, n_samples: int = 256) -> torch.Tensor:
133133
"""
134134
Takes in a (..., m) dimension tensor and returns a (..., n_classes) tensor
135135
"""

Diff for: src/blop/dofs.py

+12 -8
Original file line number | Diff line number | Diff line change
@@ -392,21 +392,25 @@ def transform(self, X: torch.Tensor) -> torch.Tensor:
392392
"""
393393
Transform X to the transformed unit hypercube.
394394
"""
395-
if X.shape[-1] != len(self):
396-
raise ValueError(f"Cannot transform points with shape {X.shape} using DOFs with dimension {len(self)}.")
395+
active_dofs = self(active=True)
396+
if X.shape[-1] != len(active_dofs):
397+
raise ValueError(
398+
f"Cannot transform points with shape {X.shape} using DOFs with active dimension {len(active_dofs)}."
399+
)
397400

398-
return torch.cat([dof._transform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(self.dofs)], dim=-1)
401+
return torch.cat([dof._transform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(active_dofs)], dim=-1)
399402

400403
def untransform(self, X: torch.Tensor) -> torch.Tensor:
401404
"""
402405
Transform the transformed unit hypercube to the search domain.
403406
"""
404-
if X.shape[-1] != len(self):
405-
raise ValueError(f"Cannot untransform points with shape {X.shape} using DOFs with dimension {len(self)}.")
407+
active_dofs = self(active=True)
408+
if X.shape[-1] != len(active_dofs):
409+
raise ValueError(
410+
f"Cannot untransform points with shape {X.shape} using DOFs with active dimension {len(active_dofs)}."
411+
)
406412

407-
return torch.cat(
408-
[dof._untransform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(self.subset(active=True))], dim=-1
409-
)
413+
return torch.cat([dof._untransform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(active_dofs)], dim=-1)
410414

411415
@property
412416
def readback(self) -> list[Any]:

0 commit comments

Comments
 (0)