Fix CI failures due to domain transforms #108

Merged: 8 commits, Mar 14, 2025
146 changes: 82 additions & 64 deletions src/blop/agent.py
@@ -19,7 +19,6 @@
from bluesky.run_engine import Msg
from botorch.acquisition.acquisition import AcquisitionFunction # type: ignore[import-untyped]
from botorch.acquisition.objective import ScalarizedPosteriorTransform # type: ignore[import-untyped]
from botorch.models.deterministic import GenericDeterministicModel # type: ignore[import-untyped]
from botorch.models.model import Model # type: ignore[import-untyped]
from botorch.models.model_list_gp_regression import ModelListGP # type: ignore[import-untyped]
from botorch.models.transforms.input import Normalize # type: ignore[import-untyped]
@@ -154,7 +153,11 @@ def raw_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.T
"""
if index is None:
return torch.stack([self.raw_inputs(dof.name) for dof in self.dofs(**subset_kwargs)], dim=-1)
return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double)

key = self.dofs[index].name
if key in self._table.columns:
return torch.tensor(self._table.loc[:, self.dofs[index].name].values, dtype=torch.double)
return torch.ones(0)

def train_inputs(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
@@ -175,7 +178,9 @@ def raw_targets_dict(self, index: str | int | None = None, **subset_kwargs) -> d
if index is None:
return {obj.name: self.raw_targets_dict(obj.name)[obj.name] for obj in self.objectives(**subset_kwargs)}
key = self.objectives[index].name
return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)}
if key in self._table.columns:
return {key: torch.tensor(self._table.loc[:, key].values, dtype=torch.double)}
return {key: torch.tensor([], dtype=torch.double)}

def raw_targets(self, index: str | int | None = None, **subset_kwargs) -> torch.Tensor:
"""
@@ -281,7 +286,7 @@ def fitness_scalarization(self, weights: str | torch.Tensor = "default") -> Scal
weights *= len(active_fitness_objectives) / weights.sum()
elif not isinstance(weights, torch.Tensor):
raise ValueError(f"'weights' must be a Tensor or one of ['default', 'equal', 'random'], and not {weights}.")
return ScalarizedPosteriorTransform(weights=weights)
return ScalarizedPosteriorTransform(weights=weights * active_fitness_objectives.signs)

@property
def fitness_model(self) -> Model:
@@ -318,24 +323,29 @@ def sample(self, n: int = DEFAULT_MAX_SAMPLES, normalize: bool = False, method:

active_dofs = self.dofs(active=True)

if method == "quasi-random":
X = utils.normalized_sobol_sampler(n, d=len(active_dofs))

elif method == "random":
X = torch.rand(size=(n, 1, len(active_dofs)))

elif method == "grid":
if method == "grid":
read_only_tensor = cast(torch.Tensor, active_dofs.read_only)
n_side_if_settable = int(np.power(n, 1 / torch.sum(~read_only_tensor)))
sides = [
torch.linspace(0, 1, n_side_if_settable) if not dof.read_only else torch.zeros(1) for dof in active_dofs
]
X = torch.cat([x.unsqueeze(-1) for x in torch.meshgrid(sides, indexing="ij")], dim=-1).unsqueeze(-2).double()
grid_sides = []
for dof in active_dofs:
if dof.read_only:
grid_sides.append(dof._transform(torch.tensor([dof.readback], dtype=torch.double)))
else:
grid_side_bins = torch.linspace(0, 1, n_side_if_settable + 1, dtype=torch.double)
grid_sides.append((grid_side_bins[:-1] + grid_side_bins[1:]) / 2)

tX = torch.stack(torch.meshgrid(grid_sides, indexing="ij"), dim=-1).unsqueeze(-2)

elif method == "quasi-random":
tX = utils.normalized_sobol_sampler(n, d=len(active_dofs))

elif method == "random":
tX = torch.rand(size=(n, 1, len(active_dofs)))

else:
raise ValueError("'method' argument must be one of ['quasi-random', 'random', 'grid'].")

return X.double() if normalize else self.dofs(active=True).untransform(X).double()
return tX.double() if normalize else self.dofs.untransform(tX.double())

# @property
def pruned_mask(self) -> torch.Tensor:
@@ -387,7 +397,6 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)

if trusted.all():
obj.validity_conjugate_model = None
obj.validity_constraint = GenericDeterministicModel(f=lambda x: torch.ones(size=x.size())[..., -1])

else:
dirichlet_likelihood = gpytorch.likelihoods.DirichletClassificationLikelihood(
@@ -402,37 +411,44 @@ def _construct_model(self, obj, skew_dims: list[tuple[int, ...]] | None = None)
input_transform=self.input_normalization,
)

obj.validity_constraint = GenericDeterministicModel(
f=lambda x: obj.validity_conjugate_model.probabilities(x)[..., -1]
)

def update_models(
self,
train: bool | None = None,
force_train: bool = False,
) -> None:
"""
We don't want to retrain the models on every call of everything, but if they are out of sync with
the DOFs then we should.
"""
Comment on lines +418 to +421 (Contributor):

Make the behavior of this parameter clear.

Suggested change:

    """
    We don't want to retrain the models on every call of everything, but if they are out of sync with
    the DOFs then we should.

    Parameters
    ----------
    force_train : bool
        Force re-training the model. If False, the model will skip re-training unless there are
        `self.train_every` more data points to train with.
    """
"""
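For context, a minimal usage sketch of the new `force_train` flag; it assumes an already-configured blop `Agent` instance named `agent`, and the DOF and objective names in the mappings are illustrative, not from this PR:

    # Sketch only: `agent` is assumed to be an already-configured blop Agent
    # (DOFs, objectives, and detectors set up elsewhere); the DOF and objective
    # names in the mappings below are illustrative.

    # Default cadence: an objective's model is only re-fit once `train_every`
    # additional valid points have accumulated since the last fit.
    agent.tell(x={"x1": [0.3]}, y={"y1": [1.2]})

    # Force a re-fit on the next update, regardless of the cadence:
    agent.tell(x={"x1": [0.4]}, y={"y1": [0.9]}, force_train=True)

    # Or re-fit all objective models directly:
    agent.update_models(force_train=True)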


active_dofs = self.dofs(active=True)
objectives_to_model = self.objectives if self.model_inactive_objectives else self.objectives(active=True)

for obj in objectives_to_model:
t0 = ttime.monotonic()
# do we need to update the model for this objective?
n_trainable_points = sum(~self.train_targets(obj.name).isnan())

# if we don't have enough points
if n_trainable_points < obj.min_points_to_train:
continue

# if the current model matches the active dofs
if getattr(obj, "model_dofs", {}) == set(active_dofs.names):
# then we can use the current hyperparameters and just update the data
cached_hypers = obj.model.state_dict() if obj.model else None

logger.debug(f'{getattr(obj, "model_dofs", {}) = }')
logger.debug(f"{set(active_dofs.names) = }")
# if there aren't enough extra points to train yet
if n_trainable_points // self.train_every == len(obj.model.train_targets) // self.train_every:
if not force_train:
Comment on lines +441 to +443 (Contributor):

This is my understanding of the logic, but it took me some time to get there. Adding a bit more explanation could be helpful to newer contributors.

Suggested change:

    # compare the new number of trainable points with the old number of train targets;
    # if there aren't enough new points yet and we are not forcing training, skip re-fitting the model for this objective
    if not force_train and n_trainable_points // self.train_every == len(obj.model.train_targets) // self.train_every:

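A small self-contained sketch of the cadence check described in this suggestion; the helper name and the numbers are illustrative, but the condition mirrors the suggested line:

    # Illustrative helper, not part of blop: it mirrors the condition in the suggestion above.
    def should_retrain(n_trainable_points: int, n_model_targets: int,
                       train_every: int, force_train: bool) -> bool:
        # Retrain when forced, or when the trainable points have crossed into a new
        # bucket of size `train_every` since the model was last fit.
        if force_train:
            return True
        return n_trainable_points // train_every != n_model_targets // train_every

    # With train_every = 4 and a model last fit on 4 points:
    print(should_retrain(6, 4, train_every=4, force_train=False))  # False: 6 // 4 == 4 // 4 == 1
    print(should_retrain(8, 4, train_every=4, force_train=False))  # True:  8 // 4 == 2, 4 // 4 == 1
    print(should_retrain(5, 4, train_every=4, force_train=True))   # True:  forced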
self._construct_model(obj)
train_model(obj.model, hypers=cached_hypers)
@thomashopkins32 (Contributor), Mar 13, 2025:

The dual use of train_model here is confusing to me. I think we should have a load_model method, separate from train_model, that takes hypers.

When you pass hypers to train_model here, it isn't actually training the model.

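A rough sketch of the split proposed in this comment, under the assumption that `load_model` only restores cached hyperparameters while `train_model` re-fits; this is a hypothetical refactor, not the implementation in this PR, and blop's actual fitting code may differ:

    # Hypothetical refactor sketch, not part of this PR.
    import gpytorch
    from botorch.fit import fit_gpytorch_mll

    def load_model(model, hypers) -> None:
        """Restore cached hyperparameters without re-fitting."""
        model.load_state_dict(hypers)

    def train_model(model) -> None:
        """Re-fit the model's hyperparameters from its current training data."""
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
        fit_gpytorch_mll(mll)

    # Call sites would then read unambiguously:
    #   load_model(obj.model, hypers=cached_hypers)   # reuse hyperparameters, refresh data
    #   train_model(obj.model)                        # actually re-fit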
continue

cached_hypers = obj.model.state_dict() if obj.model else None
n_before_tell = obj.n_valid
t0 = ttime.monotonic()
self._construct_model(obj)
if not obj.model:
raise RuntimeError(f"Expected {obj} to have a constructed model.")
n_after_tell = obj.n_valid

if train is None:
train = int(n_after_tell / self.train_every) > int(n_before_tell / self.train_every)

if len(obj.model.train_targets) >= 4:
if train:
t0 = ttime.monotonic()
train_model(obj.model)
if self.verbose:
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")

else:
train_model(obj.model, hypers=cached_hypers)
train_model(obj.model)
logger.debug(f"trained model '{obj.name}' in {1e3 * (ttime.monotonic() - t0):.00f} ms")

def tell(
self,
@@ -441,8 +457,7 @@
y: Mapping | None = {},
metadata: Mapping | None = {},
append: bool = True,
update_models: bool = True,
train: bool | None = None,
force_train: bool = False,
) -> None:
"""
Inform the agent about new inputs and targets for the model.
@@ -477,12 +492,12 @@
if len(unique_field_lengths) > 1:
raise ValueError("All supplies values must be the same length!")

# TODO: This is an innefficient approach to caching data. Keep a list, make table at update model time.
# TODO: This is an inefficient approach to caching data. Keep a list, make table at update model time.
new_table = pd.DataFrame(data)
self._table = pd.concat([self._table, new_table]) if append else new_table
self._table.index = pd.Index(np.arange(len(self._table)))
if update_models:
self.update_models(train=train)

self.update_models(force_train=force_train)

def ask(
self, acqf: str = "qei", n: int = 1, route: bool = True, sequential: bool = True, upsample: int = 1, **acqf_kwargs
@@ -525,15 +540,11 @@ def ask(
f"Can't construct non-trivial acquisition function '{acqf}' as the agent is not initialized."
)

# if the model for any active objective mismatches the active dofs, reconstrut and train it
for obj in active_objs:
if hasattr(obj, "model_dofs") and obj.model_dofs != set(active_dofs.names):
self._construct_model(obj)
train_model(obj.model)

if acqf_config["type"] == "analytic" and n > 1:
raise ValueError("Can't generate multiple design points for analytic acquisition functions.")

self.update_models()

# we may pick up some more kwargs
acqf, acqf_kwargs = _construct_acqf(self, acqf_name=acqf_config["name"], **acqf_kwargs)

@@ -556,8 +567,6 @@
# and is in the transformed model space
candidates = self.dofs(active=True).untransform(candidates)

# p = self.posterior(candidates) if hasattr(self, "model") else None

active_dofs = self.dofs(active=True)

read_only_tensor = cast(torch.Tensor, active_dofs.read_only)
@@ -714,7 +723,7 @@ def learn(
n: int = 1,
iterations: int = 1,
upsample: int = 1,
train: bool | None = None,
force_train: bool = False,
append: bool = True,
hypers: str | None = None,
route: bool = True,
@@ -768,7 +777,7 @@
metadata = {
key: new_table.loc[:, key].tolist() for key in new_table.columns if (key not in x) and (key not in y)
}
self.tell(x=x, y=y, metadata=metadata, append=append, train=train)
self.tell(x=x, y=y, metadata=metadata, append=append, force_train=force_train)

def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = 2**16):
"""
@@ -1040,8 +1049,8 @@ def _set_hypers(self, hypers: dict[str, Any]):
if not obj.model:
raise RuntimeError(f"Expected {obj} to have a constructed model.")
obj.model.load_state_dict(hypers[obj.name])
if self.validity_constraint:
self.validity_constraint.load_state_dict(hypers["validity_constraint"])
if self.validity_probability:
self.validity_probability.load_state_dict(hypers["validity_probability"])

@property
def hypers(self) -> dict[str, dict[str, dict[str, torch.Tensor]]]:
@@ -1139,13 +1148,18 @@ def plot_objectives(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
axes :
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""
if len(self.dofs(active=True, read_only=False)) == 1:
if len(self.objectives(active=True, fitness=True)) > 0:
plotting._plot_fitness_objs_one_dof(self, **kwargs)
if len(self.objectives(active=True, constraint=True)) > 0:
plotting._plot_constraint_objs_one_dof(self, **kwargs)
self.update_models()
Contributor:

I'm not sure we should be potentially re-training models here. I think it should be up to the user to decide whether they want their models updated prior to plotting. They should really do this instead:

    agent.update_models()
    agent.plot_objectives()

Suggested change (remove the call):

    self.update_models()


plottable_dofs = self.dofs(active=True, read_only=False)
logger.debug(f"Plotting agent with DOFs {self.dofs} and objectives {self.objectives}")
if len(plottable_dofs) == 0:
raise ValueError("To plot agent objectives, at least one writeable DOF must be active.")
elif len(plottable_dofs) == 1:
plotting._plot_objs_one_dof(self, **kwargs)
elif len(plottable_dofs) == 2:
plotting._plot_objs_many_dofs(self, gridded=True, axes=axes, **kwargs)
else:
plotting._plot_objs_many_dofs(self, axes=axes, **kwargs)
plotting._plot_objs_many_dofs(self, gridded=False, axes=axes, **kwargs)

def plot_acquisition(self, acqf: str = "ei", axes: tuple[int, int] = (0, 1), **kwargs) -> None:
"""Plot an acquisition function over test inputs sampling the limits of the parameter space.
@@ -1157,6 +1171,8 @@ def plot_acquisition(self, acqf: str = "ei", axes: tuple[int, int] = (0, 1), **k
axes :
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""
self.update_models()
Contributor:

Same thing here as the other comment.

Suggested change (remove the call):

    self.update_models()


if len(self.dofs(active=True, read_only=False)) == 1:
plotting._plot_acqf_one_dof(self, acqfs=np.atleast_1d(acqf), **kwargs)
else:
@@ -1170,6 +1186,8 @@ def plot_validity(self, axes: tuple[int, int] = (0, 1), **kwargs) -> None:
axes :
A tuple specifying which DOFs to plot as a function of. Can be either an int or the name of DOFs.
"""
self.update_models()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing here as other comment.

Suggested change
self.update_models()


if len(self.dofs(active=True, read_only=False)) == 1:
plotting._plot_valid_one_dof(self, **kwargs)
else:
2 changes: 1 addition & 1 deletion src/blop/bayesian/models.py
@@ -129,7 +129,7 @@ def __init__(

self.trained: bool = False

def probabilities(self, x: torch.Tensor, n_samples: int = 1024) -> torch.Tensor:
def probabilities(self, x: torch.Tensor, n_samples: int = 256) -> torch.Tensor:
"""
Takes in a (..., m) dimension tensor and returns a (..., n_classes) tensor
"""
20 changes: 12 additions & 8 deletions src/blop/dofs.py
@@ -392,21 +392,25 @@ def transform(self, X: torch.Tensor) -> torch.Tensor:
"""
Transform X to the transformed unit hypercube.
"""
if X.shape[-1] != len(self):
raise ValueError(f"Cannot transform points with shape {X.shape} using DOFs with dimension {len(self)}.")
active_dofs = self(active=True)
if X.shape[-1] != len(active_dofs):
raise ValueError(
f"Cannot transform points with shape {X.shape} using DOFs with active dimension {len(active_dofs)}."
)

return torch.cat([dof._transform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(self.dofs)], dim=-1)
return torch.cat([dof._transform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(active_dofs)], dim=-1)

def untransform(self, X: torch.Tensor) -> torch.Tensor:
"""
Transform the transformed unit hypercube to the search domain.
"""
if X.shape[-1] != len(self):
raise ValueError(f"Cannot untransform points with shape {X.shape} using DOFs with dimension {len(self)}.")
active_dofs = self(active=True)
if X.shape[-1] != len(active_dofs):
raise ValueError(
f"Cannot untransform points with shape {X.shape} using DOFs with active dimension {len(active_dofs)}."
)

return torch.cat(
[dof._untransform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(self.subset(active=True))], dim=-1
)
return torch.cat([dof._untransform(X[..., i]).unsqueeze(-1) for i, dof in enumerate(active_dofs)], dim=-1)

@property
def readback(self) -> list[Any]: