19
19
from bluesky .run_engine import Msg
20
20
from botorch .acquisition .acquisition import AcquisitionFunction # type: ignore[import-untyped]
21
21
from botorch .acquisition .objective import ScalarizedPosteriorTransform # type: ignore[import-untyped]
22
- from botorch .models .deterministic import GenericDeterministicModel # type: ignore[import-untyped]
23
22
from botorch .models .model import Model # type: ignore[import-untyped]
24
23
from botorch .models .model_list_gp_regression import ModelListGP # type: ignore[import-untyped]
25
24
from botorch .models .transforms .input import Normalize # type: ignore[import-untyped]
@@ -154,7 +153,11 @@ def raw_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.T
154
153
"""
155
154
if index is None :
156
155
return torch .stack ([self .raw_inputs (dof .name ) for dof in self .dofs (** subset_kwargs )], dim = - 1 )
157
- return torch .tensor (self ._table .loc [:, self .dofs [index ].name ].values , dtype = torch .double )
156
+
157
+ key = self .dofs [index ].name
158
+ if key in self ._table .columns :
159
+ return torch .tensor (self ._table .loc [:, self .dofs [index ].name ].values , dtype = torch .double )
160
+ return torch .ones (0 )
158
161
159
162
def train_inputs (self , index : str | int | None = None , ** subset_kwargs ) -> torch .Tensor :
160
163
"""
@@ -175,7 +178,9 @@ def raw_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> d
175
178
if index is None :
176
179
return {obj .name : self .raw_targets_dict (obj .name )[obj .name ] for obj in self .objectives (** subset_kwargs )}
177
180
key = self .objectives [index ].name
178
- return {key : torch .tensor (self ._table .loc [:, key ].values , dtype = torch .double )}
181
+ if key in self ._table .columns :
182
+ return {key : torch .tensor (self ._table .loc [:, key ].values , dtype = torch .double )}
183
+ return {key : torch .tensor ([], dtype = torch .double )}
179
184
180
185
def raw_targets (self , index : str | int | None = None , ** subset_kwargs ) -> torch .Tensor :
181
186
"""
@@ -281,7 +286,7 @@ def fitness_scalarization(self, weights: str | torch.Tensor = "default") -> Scal
281
286
weights *= len (active_fitness_objectives ) / weights .sum ()
282
287
elif not isinstance (weights , torch .Tensor ):
283
288
raise ValueError (f"'weights' must be a Tensor or one of ['default', 'equal', 'random'], and not { weights } ." )
284
- return ScalarizedPosteriorTransform (weights = weights )
289
+ return ScalarizedPosteriorTransform (weights = weights * active_fitness_objectives . signs )
285
290
286
291
@property
287
292
def fitness_model (self ) -> Model :
@@ -318,24 +323,29 @@ def sample(self, n: int = DEFAULT_MAX_SAMPLES, normalize: bool = False, method:
318
323
319
324
active_dofs = self .dofs (active = True )
320
325
321
- if method == "quasi-random" :
322
- X = utils .normalized_sobol_sampler (n , d = len (active_dofs ))
323
-
324
- elif method == "random" :
325
- X = torch .rand (size = (n , 1 , len (active_dofs )))
326
-
327
- elif method == "grid" :
326
+ if method == "grid" :
328
327
read_only_tensor = cast (torch .Tensor , active_dofs .read_only )
329
328
n_side_if_settable = int (np .power (n , 1 / torch .sum (~ read_only_tensor )))
330
- sides = [
331
- torch .linspace (0 , 1 , n_side_if_settable ) if not dof .read_only else torch .zeros (1 ) for dof in active_dofs
332
- ]
333
- X = torch .cat ([x .unsqueeze (- 1 ) for x in torch .meshgrid (sides , indexing = "ij" )], dim = - 1 ).unsqueeze (- 2 ).double ()
329
+ grid_sides = []
330
+ for dof in active_dofs :
331
+ if dof .read_only :
332
+ grid_sides .append (dof ._transform (torch .tensor ([dof .readback ], dtype = torch .double )))
333
+ else :
334
+ grid_side_bins = torch .linspace (0 , 1 , n_side_if_settable + 1 , dtype = torch .double )
335
+ grid_sides .append ((grid_side_bins [:- 1 ] + grid_side_bins [1 :]) / 2 )
336
+
337
+ tX = torch .stack (torch .meshgrid (grid_sides , indexing = "ij" ), dim = - 1 ).unsqueeze (- 2 )
338
+
339
+ elif method == "quasi-random" :
340
+ tX = utils .normalized_sobol_sampler (n , d = len (active_dofs ))
341
+
342
+ elif method == "random" :
343
+ tX = torch .rand (size = (n , 1 , len (active_dofs )))
334
344
335
345
else :
336
346
raise ValueError ("'method' argument must be one of ['quasi-random', 'random', 'grid']." )
337
347
338
- return X .double () if normalize else self .dofs ( active = True ) .untransform (X ) .double ()
348
+ return tX .double () if normalize else self .dofs .untransform (tX .double () )
339
349
340
350
# @property
341
351
def pruned_mask (self ) -> torch .Tensor :
@@ -387,7 +397,6 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)
387
397
388
398
if trusted .all ():
389
399
obj .validity_conjugate_model = None
390
- obj .validity_constraint = GenericDeterministicModel (f = lambda x : torch .ones (size = x .size ())[..., - 1 ])
391
400
392
401
else :
393
402
dirichlet_likelihood = gpytorch .likelihoods .DirichletClassificationLikelihood (
@@ -402,37 +411,44 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)
402
411
input_transform = self .input_normalization ,
403
412
)
404
413
405
- obj .validity_constraint = GenericDeterministicModel (
406
- f = lambda x : obj .validity_conjugate_model .probabilities (x )[..., - 1 ]
407
- )
408
-
409
414
def update_models (
410
415
self ,
411
- train : bool | None = None ,
416
+ force_train : bool = False ,
412
417
) -> None :
418
+ """
419
+ We don't want to retrain the models on every call of everything, but if they are out of sync with
420
+ the DOFs then we should.
421
+ """
422
+
423
+ active_dofs = self .dofs (active = True )
413
424
objectives_to_model = self .objectives if self .model_inactive_objectives else self .objectives (active = True )
425
+
414
426
for obj in objectives_to_model :
415
- t0 = ttime .monotonic ()
427
+ # do we need to update the model for this objective?
428
+ n_trainable_points = sum (~ self .train_targets (obj .name ).isnan ())
429
+
430
+ # if we don't have enough points
431
+ if n_trainable_points < obj .min_points_to_train :
432
+ continue
433
+
434
+ # if the current model matches the active dofs
435
+ if getattr (obj , "model_dofs" , {}) == set (active_dofs .names ):
436
+ # then we can use the current hyperparameters and just update the data
437
+ cached_hypers = obj .model .state_dict () if obj .model else None
438
+
439
+ logger .debug (f'{ getattr (obj , "model_dofs" , {}) = } ' )
440
+ logger .debug (f"{ set (active_dofs .names ) = } " )
441
+ # if there aren't enough extra points to train yet
442
+ if n_trainable_points // self .train_every == len (obj .model .train_targets ) // self .train_every :
443
+ if not force_train :
444
+ self ._construct_model (obj )
445
+ train_model (obj .model , hypers = cached_hypers )
446
+ continue
416
447
417
- cached_hypers = obj .model .state_dict () if obj .model else None
418
- n_before_tell = obj .n_valid
448
+ t0 = ttime .monotonic ()
419
449
self ._construct_model (obj )
420
- if not obj .model :
421
- raise RuntimeError (f"Expected { obj } to have a constructed model." )
422
- n_after_tell = obj .n_valid
423
-
424
- if train is None :
425
- train = int (n_after_tell / self .train_every ) > int (n_before_tell / self .train_every )
426
-
427
- if len (obj .model .train_targets ) >= 4 :
428
- if train :
429
- t0 = ttime .monotonic ()
430
- train_model (obj .model )
431
- if self .verbose :
432
- logger .debug (f"trained model '{ obj .name } ' in { 1e3 * (ttime .monotonic () - t0 ):.00f} ms" )
433
-
434
- else :
435
- train_model (obj .model , hypers = cached_hypers )
450
+ train_model (obj .model )
451
+ logger .debug (f"trained model '{ obj .name } ' in { 1e3 * (ttime .monotonic () - t0 ):.00f} ms" )
436
452
437
453
def tell (
438
454
self ,
@@ -441,8 +457,7 @@ def tell(
441
457
y : Mapping | None = {},
442
458
metadata : Mapping | None = {},
443
459
append : bool = True ,
444
- update_models : bool = True ,
445
- train : bool | None = None ,
460
+ force_train : bool = False ,
446
461
) -> None :
447
462
"""
448
463
Inform the agent about new inputs and targets for the model.
@@ -477,12 +492,12 @@ def tell(
477
492
if len (unique_field_lengths ) > 1 :
478
493
raise ValueError ("All supplies values must be the same length!" )
479
494
480
- # TODO: This is an innefficient approach to caching data. Keep a list, make table at update model time.
495
+ # TODO: This is an inefficient approach to caching data. Keep a list, make table at update model time.
481
496
new_table = pd .DataFrame (data )
482
497
self ._table = pd .concat ([self ._table , new_table ]) if append else new_table
483
498
self ._table .index = pd .Index (np .arange (len (self ._table )))
484
- if update_models :
485
- self .update_models (train = train )
499
+
500
+ self .update_models (force_train = force_train )
486
501
487
502
def ask (
488
503
self , acqf : str = "qei" , n : int = 1 , route : bool = True , sequential : bool = True , upsample : int = 1 , ** acqf_kwargs
@@ -525,15 +540,11 @@ def ask(
525
540
f"Can't construct non-trivial acquisition function '{ acqf } ' as the agent is not initialized."
526
541
)
527
542
528
- # if the model for any active objective mismatches the active dofs, reconstrut and train it
529
- for obj in active_objs :
530
- if hasattr (obj , "model_dofs" ) and obj .model_dofs != set (active_dofs .names ):
531
- self ._construct_model (obj )
532
- train_model (obj .model )
533
-
534
543
if acqf_config ["type" ] == "analytic" and n > 1 :
535
544
raise ValueError ("Can't generate multiple design points for analytic acquisition functions." )
536
545
546
+ self .update_models ()
547
+
537
548
# we may pick up some more kwargs
538
549
acqf , acqf_kwargs = _construct_acqf (self , acqf_name = acqf_config ["name" ], ** acqf_kwargs )
539
550
@@ -556,8 +567,6 @@ def ask(
556
567
# and is in the transformed model space
557
568
candidates = self .dofs (active = True ).untransform (candidates )
558
569
559
- # p = self.posterior(candidates) if hasattr(self, "model") else None
560
-
561
570
active_dofs = self .dofs (active = True )
562
571
563
572
read_only_tensor = cast (torch .Tensor , active_dofs .read_only )
@@ -714,7 +723,7 @@ def learn(
714
723
n : int = 1 ,
715
724
iterations : int = 1 ,
716
725
upsample : int = 1 ,
717
- train : bool | None = None ,
726
+ force_train : bool = False ,
718
727
append : bool = True ,
719
728
hypers : str | None = None ,
720
729
route : bool = True ,
@@ -768,7 +777,7 @@ def learn(
768
777
metadata = {
769
778
key : new_table .loc [:, key ].tolist () for key in new_table .columns if (key not in x ) and (key not in y )
770
779
}
771
- self .tell (x = x , y = y , metadata = metadata , append = append , train = train )
780
+ self .tell (x = x , y = y , metadata = metadata , append = append , force_train = force_train )
772
781
773
782
def view (self , item : str = "mean" , cmap : str = "turbo" , max_inputs : int = 2 ** 16 ):
774
783
"""
@@ -1040,8 +1049,8 @@ def _set_hypers(self, hypers: dict[str, Any]):
1040
1049
if not obj .model :
1041
1050
raise RuntimeError (f"Expected { obj } to have a constructed model." )
1042
1051
obj .model .load_state_dict (hypers [obj .name ])
1043
- if self .validity_constraint :
1044
- self .validity_constraint .load_state_dict (hypers ["validity_constraint " ])
1052
+ if self .validity_probability :
1053
+ self .validity_probability .load_state_dict (hypers ["validity_probability " ])
1045
1054
1046
1055
@property
1047
1056
def hypers (self ) -> dict [str , dict [str , dict [str , torch .Tensor ]]]:
@@ -1139,13 +1148,18 @@ def plot_objectives(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
1139
1148
axes :
1140
1149
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
1141
1150
"""
1142
- if len (self .dofs (active = True , read_only = False )) == 1 :
1143
- if len (self .objectives (active = True , fitness = True )) > 0 :
1144
- plotting ._plot_fitness_objs_one_dof (self , ** kwargs )
1145
- if len (self .objectives (active = True , constraint = True )) > 0 :
1146
- plotting ._plot_constraint_objs_one_dof (self , ** kwargs )
1151
+ self .update_models ()
1152
+
1153
+ plottable_dofs = self .dofs (active = True , read_only = False )
1154
+ logger .debug (f"Plotting agent with DOFs { self .dofs } and objectives { self .objectives } " )
1155
+ if len (plottable_dofs ) == 0 :
1156
+ raise ValueError ("To plot agent objectives, at least one writeable DOF must be active." )
1157
+ elif len (plottable_dofs ) == 1 :
1158
+ plotting ._plot_objs_one_dof (self , ** kwargs )
1159
+ elif len (plottable_dofs ) == 2 :
1160
+ plotting ._plot_objs_many_dofs (self , gridded = True , axes = axes , ** kwargs )
1147
1161
else :
1148
- plotting ._plot_objs_many_dofs (self , axes = axes , ** kwargs )
1162
+ plotting ._plot_objs_many_dofs (self , gridded = False , axes = axes , ** kwargs )
1149
1163
1150
1164
def plot_acquisition (self , acqf : str = "ei" , axes : tuple [int , int ] = (0 , 1 ), ** kwargs ) -> None :
1151
1165
"""Plot an acquisition function over test inputs sampling the limits of the parameter space.
@@ -1157,6 +1171,8 @@ def plot_acquisition(self, acqf: str = "ei", axes: tuple[int, int] = (0, 1), **k
1157
1171
axes :
1158
1172
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
1159
1173
"""
1174
+ self .update_models ()
1175
+
1160
1176
if len (self .dofs (active = True , read_only = False )) == 1 :
1161
1177
plotting ._plot_acqf_one_dof (self , acqfs = np .atleast_1d (acqf ), ** kwargs )
1162
1178
else :
@@ -1170,6 +1186,8 @@ def plot_validity(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
1170
1186
axes :
1171
1187
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
1172
1188
"""
1189
+ self .update_models ()
1190
+
1173
1191
if len (self .dofs (active = True , read_only = False )) == 1 :
1174
1192
plotting ._plot_valid_one_dof (self , ** kwargs )
1175
1193
else :
0 commit comments