From 217e55bd5c5032ddb33ae724a219f5b2fbdf2d57 Mon Sep 17 00:00:00 2001 From: ktschuett Date: Wed, 9 Nov 2022 14:18:07 +0100 Subject: [PATCH 1/2] Fix configs and docstrings --- src/schnetpack/atomistic/atomwise.py | 2 +- src/schnetpack/cli.py | 2 +- .../configs/callbacks/earlystopping.yaml | 4 +- src/schnetpack/configs/callbacks/ema.yaml | 2 +- src/schnetpack/configs/data/qm9.yaml | 4 +- .../{qm9_energy.yaml => qm9_atomwise.yaml} | 3 +- .../configs/experiment/qm9_dipole.yaml | 44 +++++++++++++++++++ .../model/representation/field_schnet.yaml | 7 ++- .../configs/model/representation/painn.yaml | 7 ++- .../representation/radial_basis/bessel.yaml | 3 ++ .../representation/radial_basis/gaussian.yaml | 3 ++ .../configs/model/representation/schnet.yaml | 9 ++-- .../configs/model/representation/so3net.yaml | 7 ++- .../task/scheduler/reduce_on_plateau.yaml | 4 +- src/schnetpack/data/splitting.py | 1 - src/schnetpack/datasets/qm9.py | 6 +-- src/schnetpack/representation/so3net.py | 1 + 17 files changed, 78 insertions(+), 31 deletions(-) rename src/schnetpack/configs/experiment/{qm9_energy.yaml => qm9_atomwise.yaml} (94%) create mode 100644 src/schnetpack/configs/experiment/qm9_dipole.yaml create mode 100644 src/schnetpack/configs/model/representation/radial_basis/bessel.yaml create mode 100644 src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml diff --git a/src/schnetpack/atomistic/atomwise.py b/src/schnetpack/atomistic/atomwise.py index 197dec7f0..c1ce4f6a3 100644 --- a/src/schnetpack/atomistic/atomwise.py +++ b/src/schnetpack/atomistic/atomwise.py @@ -127,7 +127,7 @@ def __init__( return_charges: If true, return latent partial charges dipole_key: the key under which the dipoles will be stored charges_key: the key under which partial charges will be stored - correct_charges: If true, forces the sum of partial charges to be the the total charge, if provided, + correct_charges: If true, forces the sum of partial charges to be the total charge, if provided, and zero otherwise. use_vector_representation: If true, use vector representation to predict local, atomic dipoles. diff --git a/src/schnetpack/cli.py b/src/schnetpack/cli.py index 30762c6f0..532f9c1fc 100644 --- a/src/schnetpack/cli.py +++ b/src/schnetpack/cli.py @@ -40,7 +40,7 @@ def train(config: DictConfig): """ print(header) - log.info("Runnning on host: " + str(socket.gethostname())) + log.info("Running on host: " + str(socket.gethostname())) if OmegaConf.is_missing(config, "run.data_dir"): log.error( diff --git a/src/schnetpack/configs/callbacks/earlystopping.yaml b/src/schnetpack/configs/callbacks/earlystopping.yaml index 05a8900c7..80e366a37 100644 --- a/src/schnetpack/configs/callbacks/earlystopping.yaml +++ b/src/schnetpack/configs/callbacks/earlystopping.yaml @@ -1,6 +1,6 @@ early_stopping: _target_: pytorch_lightning.callbacks.EarlyStopping monitor: "val_loss" # name of the logged metric which determines when model is improving - patience: 100 # how many epochs of not improving until training stops + patience: 150 # how many epochs of not improving until training stops mode: "min" # can be "max" or "min" - min_delta: 1e-5 # minimum change in the monitored metric needed to qualify as an improvement \ No newline at end of file + min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement \ No newline at end of file diff --git a/src/schnetpack/configs/callbacks/ema.yaml b/src/schnetpack/configs/callbacks/ema.yaml index 36fcd7425..9d8a11294 100644 --- a/src/schnetpack/configs/callbacks/ema.yaml +++ b/src/schnetpack/configs/callbacks/ema.yaml @@ -1,3 +1,3 @@ ema: _target_: schnetpack.train.ExponentialMovingAverage - decay: 0.995 \ No newline at end of file + decay: 0.9 \ No newline at end of file diff --git a/src/schnetpack/configs/data/qm9.yaml b/src/schnetpack/configs/data/qm9.yaml index 306ac08cc..9de33eb6a 100644 --- a/src/schnetpack/configs/data/qm9.yaml +++ b/src/schnetpack/configs/data/qm9.yaml @@ -5,8 +5,8 @@ _target_: schnetpack.datasets.QM9 datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml batch_size: 100 -num_train: 110000 -num_val: 10000 +num_train: 105000 +num_val: 5000 remove_uncharacterized: False # convert to typically used units diff --git a/src/schnetpack/configs/experiment/qm9_energy.yaml b/src/schnetpack/configs/experiment/qm9_atomwise.yaml similarity index 94% rename from src/schnetpack/configs/experiment/qm9_energy.yaml rename to src/schnetpack/configs/experiment/qm9_atomwise.yaml index 4b6358f59..81567a1b4 100644 --- a/src/schnetpack/configs/experiment/qm9_energy.yaml +++ b/src/schnetpack/configs/experiment/qm9_atomwise.yaml @@ -11,6 +11,7 @@ globals: cutoff: 5. lr: 5e-4 property: energy_U0 + aggregation: sum data: transforms: @@ -28,7 +29,7 @@ model: - _target_: schnetpack.atomistic.Atomwise output_key: ${globals.property} n_in: ${model.representation.n_atom_basis} - aggregation_mode: sum + aggregation_mode: ${globals.aggregation} postprocessors: - _target_: schnetpack.transform.CastTo64 - _target_: schnetpack.transform.AddOffsets diff --git a/src/schnetpack/configs/experiment/qm9_dipole.yaml b/src/schnetpack/configs/experiment/qm9_dipole.yaml new file mode 100644 index 000000000..48ab3488d --- /dev/null +++ b/src/schnetpack/configs/experiment/qm9_dipole.yaml @@ -0,0 +1,44 @@ +# @package _global_ + +defaults: + - override /model: nnp + - override /data: qm9 + +run: + experiment: qm9_${globals.property} + +globals: + cutoff: 5. + lr: 5e-4 + property: dipole_moment + +data: + transforms: + - _target_: schnetpack.transform.SubtractCenterOfMass + - _target_: schnetpack.transform.MatScipyNeighborList + cutoff: ${globals.cutoff} + - _target_: schnetpack.transform.CastTo32 + +model: + output_modules: + - _target_: schnetpack.atomistic.DipoleMoment + dipole_key: ${globals.property} + n_in: ${model.representation.n_atom_basis} + predict_magnitude: True + use_vector_representation: False + postprocessors: + - _target_: schnetpack.transform.CastTo64 + +task: + outputs: + - _target_: schnetpack.task.ModelOutput + name: ${globals.property} + loss_fn: + _target_: torch.nn.MSELoss + metrics: + mae: + _target_: torchmetrics.regression.MeanAbsoluteError + rmse: + _target_: torchmetrics.regression.MeanSquaredError + squared: False + loss_weight: 1. \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/field_schnet.yaml b/src/schnetpack/configs/model/representation/field_schnet.yaml index 229635774..d508c479c 100644 --- a/src/schnetpack/configs/model/representation/field_schnet.yaml +++ b/src/schnetpack/configs/model/representation/field_schnet.yaml @@ -1,13 +1,12 @@ +defaults: + - radial_basis: gaussian + _target_: schnetpack.representation.FieldSchNet n_atom_basis: 128 n_interactions: 5 external_fields: [] response_properties: ${globals.response_properties} shared_interactions: False -radial_basis: - _target_: schnetpack.nn.radial.GaussianRBF - n_rbf: 20 - cutoff: ${globals.cutoff} cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/painn.yaml b/src/schnetpack/configs/model/representation/painn.yaml index a72515d0f..b8fecf3f5 100644 --- a/src/schnetpack/configs/model/representation/painn.yaml +++ b/src/schnetpack/configs/model/representation/painn.yaml @@ -1,12 +1,11 @@ +defaults: + - radial_basis: gaussian + _target_: schnetpack.representation.PaiNN n_atom_basis: 128 n_interactions: 3 shared_interactions: False shared_filters: False -radial_basis: - _target_: schnetpack.nn.radial.GaussianRBF - n_rbf: 20 - cutoff: ${globals.cutoff} cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/radial_basis/bessel.yaml b/src/schnetpack/configs/model/representation/radial_basis/bessel.yaml new file mode 100644 index 000000000..6346937a5 --- /dev/null +++ b/src/schnetpack/configs/model/representation/radial_basis/bessel.yaml @@ -0,0 +1,3 @@ +_target_: schnetpack.nn.radial.BesselRBF +n_rbf: 20 +cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml b/src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml new file mode 100644 index 000000000..64c2305f8 --- /dev/null +++ b/src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml @@ -0,0 +1,3 @@ +_target_: schnetpack.nn.radial.GaussianRBF +n_rbf: 20 +cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/schnet.yaml b/src/schnetpack/configs/model/representation/schnet.yaml index ee998aff8..f0d398ee0 100644 --- a/src/schnetpack/configs/model/representation/schnet.yaml +++ b/src/schnetpack/configs/model/representation/schnet.yaml @@ -1,10 +1,9 @@ +defaults: + - radial_basis: gaussian + _target_: schnetpack.representation.SchNet n_atom_basis: 128 -n_interactions: 6 -radial_basis: - _target_: schnetpack.nn.radial.GaussianRBF - n_rbf: 20 - cutoff: ${globals.cutoff} +n_interactions: 3 cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/model/representation/so3net.yaml b/src/schnetpack/configs/model/representation/so3net.yaml index 2935ec4ba..2b52bbf04 100644 --- a/src/schnetpack/configs/model/representation/so3net.yaml +++ b/src/schnetpack/configs/model/representation/so3net.yaml @@ -1,13 +1,12 @@ +defaults: + - radial_basis: gaussian + _target_: schnetpack.representation.SO3net _recursive_: False n_atom_basis: 64 n_interactions: 3 lmax: 2 shared_interactions: False -radial_basis: - _target_: schnetpack.nn.radial.GaussianRBF - n_rbf: 20 - cutoff: ${globals.cutoff} cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml b/src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml index 491530352..57f561a01 100644 --- a/src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml +++ b/src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml @@ -4,8 +4,8 @@ scheduler_monitor: val_loss scheduler_args: mode: min factor: 0.8 - patience: 50 - threshold: 1e-4 + patience: 75 + threshold: 0.0 threshold_mode: rel cooldown: 10 min_lr: 0.0 diff --git a/src/schnetpack/data/splitting.py b/src/schnetpack/data/splitting.py index 0d2d1b4bf..dfd1b98b0 100644 --- a/src/schnetpack/data/splitting.py +++ b/src/schnetpack/data/splitting.py @@ -166,5 +166,4 @@ def split(self, dataset, *split_sizes): ) for i, split_idx in zip(split_partition_idx[src], partition_split_indices): split_indices[i] = np.array(partition)[split_idx].tolist() - return split_indices diff --git a/src/schnetpack/datasets/qm9.py b/src/schnetpack/datasets/qm9.py index 2c913cdb9..24a32fd1b 100644 --- a/src/schnetpack/datasets/qm9.py +++ b/src/schnetpack/datasets/qm9.py @@ -133,12 +133,12 @@ def prepare_data(self): QM9.A: "GHz", QM9.B: "GHz", QM9.C: "GHz", - QM9.mu: "D", - QM9.alpha: "a0^3", + QM9.mu: "Debye", + QM9.alpha: "a0 a0 a0", QM9.homo: "Ha", QM9.lumo: "Ha", QM9.gap: "Ha", - QM9.r2: "a0^2", + QM9.r2: "a0 a0", QM9.zpve: "Ha", QM9.U0: "Ha", QM9.U: "Ha", diff --git a/src/schnetpack/representation/so3net.py b/src/schnetpack/representation/so3net.py index 7a10ba0f7..c8ecb7e5b 100644 --- a/src/schnetpack/representation/so3net.py +++ b/src/schnetpack/representation/so3net.py @@ -118,5 +118,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]): x = x + dx inputs["scalar_representation"] = x[:, 0] + inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1) inputs["multipole_representation"] = x return inputs From fa6f10aca968f6f609a290c5723a84222648bebf Mon Sep 17 00:00:00 2001 From: ktschuett Date: Wed, 9 Nov 2022 15:45:49 +0100 Subject: [PATCH 2/2] Add docstring --- src/schnetpack/configs/model/representation/schnet.yaml | 2 +- src/schnetpack/representation/so3net.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/schnetpack/configs/model/representation/schnet.yaml b/src/schnetpack/configs/model/representation/schnet.yaml index f0d398ee0..1eb60bdbd 100644 --- a/src/schnetpack/configs/model/representation/schnet.yaml +++ b/src/schnetpack/configs/model/representation/schnet.yaml @@ -3,7 +3,7 @@ defaults: _target_: schnetpack.representation.SchNet n_atom_basis: 128 -n_interactions: 3 +n_interactions: 6 cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/representation/so3net.py b/src/schnetpack/representation/so3net.py index c8ecb7e5b..7ff8cb160 100644 --- a/src/schnetpack/representation/so3net.py +++ b/src/schnetpack/representation/so3net.py @@ -118,6 +118,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]): x = x + dx inputs["scalar_representation"] = x[:, 0] + # extract cartesian vector from multipoles: [y, z, x] -> [x, y, z] inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1) inputs["multipole_representation"] = x return inputs