Skip to content

Fix configs and docstrings #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/schnetpack/atomistic/atomwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
return_charges: If true, return latent partial charges
dipole_key: the key under which the dipoles will be stored
charges_key: the key under which partial charges will be stored
correct_charges: If true, forces the sum of partial charges to be the the total charge, if provided,
correct_charges: If true, forces the sum of partial charges to be the total charge, if provided,
and zero otherwise.
use_vector_representation: If true, use vector representation to predict local,
atomic dipoles.
Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train(config: DictConfig):

"""
print(header)
log.info("Runnning on host: " + str(socket.gethostname()))
log.info("Running on host: " + str(socket.gethostname()))

if OmegaConf.is_missing(config, "run.data_dir"):
log.error(
Expand Down
4 changes: 2 additions & 2 deletions src/schnetpack/configs/callbacks/earlystopping.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
monitor: "val_loss" # name of the logged metric which determines when model is improving
patience: 100 # how many epochs of not improving until training stops
patience: 150 # how many epochs of not improving until training stops
mode: "min" # can be "max" or "min"
min_delta: 1e-5 # minimum change in the monitored metric needed to qualify as an improvement
min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement
2 changes: 1 addition & 1 deletion src/schnetpack/configs/callbacks/ema.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
ema:
_target_: schnetpack.train.ExponentialMovingAverage
decay: 0.995
decay: 0.9
4 changes: 2 additions & 2 deletions src/schnetpack/configs/data/qm9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ _target_: schnetpack.datasets.QM9

datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml
batch_size: 100
num_train: 110000
num_val: 10000
num_train: 105000
num_val: 5000
remove_uncharacterized: False

# convert to typically used units
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ globals:
cutoff: 5.
lr: 5e-4
property: energy_U0
aggregation: sum

data:
transforms:
Expand All @@ -28,7 +29,7 @@ model:
- _target_: schnetpack.atomistic.Atomwise
output_key: ${globals.property}
n_in: ${model.representation.n_atom_basis}
aggregation_mode: sum
aggregation_mode: ${globals.aggregation}
postprocessors:
- _target_: schnetpack.transform.CastTo64
- _target_: schnetpack.transform.AddOffsets
Expand Down
44 changes: 44 additions & 0 deletions src/schnetpack/configs/experiment/qm9_dipole.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# @package _global_

defaults:
- override /model: nnp
- override /data: qm9

run:
experiment: qm9_${globals.property}

globals:
cutoff: 5.
lr: 5e-4
property: dipole_moment

data:
transforms:
- _target_: schnetpack.transform.SubtractCenterOfMass
- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${globals.cutoff}
- _target_: schnetpack.transform.CastTo32

model:
output_modules:
- _target_: schnetpack.atomistic.DipoleMoment
dipole_key: ${globals.property}
n_in: ${model.representation.n_atom_basis}
predict_magnitude: True
use_vector_representation: False
postprocessors:
- _target_: schnetpack.transform.CastTo64

task:
outputs:
- _target_: schnetpack.task.ModelOutput
name: ${globals.property}
loss_fn:
_target_: torch.nn.MSELoss
metrics:
mae:
_target_: torchmetrics.regression.MeanAbsoluteError
rmse:
_target_: torchmetrics.regression.MeanSquaredError
squared: False
loss_weight: 1.
7 changes: 3 additions & 4 deletions src/schnetpack/configs/model/representation/field_schnet.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
defaults:
- radial_basis: gaussian

_target_: schnetpack.representation.FieldSchNet
n_atom_basis: 128
n_interactions: 5
external_fields: []
response_properties: ${globals.response_properties}
shared_interactions: False
radial_basis:
_target_: schnetpack.nn.radial.GaussianRBF
n_rbf: 20
cutoff: ${globals.cutoff}
cutoff_fn:
_target_: schnetpack.nn.cutoff.CosineCutoff
cutoff: ${globals.cutoff}
7 changes: 3 additions & 4 deletions src/schnetpack/configs/model/representation/painn.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
defaults:
- radial_basis: gaussian

_target_: schnetpack.representation.PaiNN
n_atom_basis: 128
n_interactions: 3
shared_interactions: False
shared_filters: False
radial_basis:
_target_: schnetpack.nn.radial.GaussianRBF
n_rbf: 20
cutoff: ${globals.cutoff}
cutoff_fn:
_target_: schnetpack.nn.cutoff.CosineCutoff
cutoff: ${globals.cutoff}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: schnetpack.nn.radial.BesselRBF
n_rbf: 20
cutoff: ${globals.cutoff}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: schnetpack.nn.radial.GaussianRBF
n_rbf: 20
cutoff: ${globals.cutoff}
7 changes: 3 additions & 4 deletions src/schnetpack/configs/model/representation/schnet.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
defaults:
- radial_basis: gaussian

_target_: schnetpack.representation.SchNet
n_atom_basis: 128
n_interactions: 6
radial_basis:
_target_: schnetpack.nn.radial.GaussianRBF
n_rbf: 20
cutoff: ${globals.cutoff}
cutoff_fn:
_target_: schnetpack.nn.cutoff.CosineCutoff
cutoff: ${globals.cutoff}
7 changes: 3 additions & 4 deletions src/schnetpack/configs/model/representation/so3net.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
defaults:
- radial_basis: gaussian

_target_: schnetpack.representation.SO3net
_recursive_: False
n_atom_basis: 64
n_interactions: 3
lmax: 2
shared_interactions: False
radial_basis:
_target_: schnetpack.nn.radial.GaussianRBF
n_rbf: 20
cutoff: ${globals.cutoff}
cutoff_fn:
_target_: schnetpack.nn.cutoff.CosineCutoff
cutoff: ${globals.cutoff}
4 changes: 2 additions & 2 deletions src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ scheduler_monitor: val_loss
scheduler_args:
mode: min
factor: 0.8
patience: 50
threshold: 1e-4
patience: 75
threshold: 0.0
threshold_mode: rel
cooldown: 10
min_lr: 0.0
Expand Down
1 change: 0 additions & 1 deletion src/schnetpack/data/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,5 +166,4 @@ def split(self, dataset, *split_sizes):
)
for i, split_idx in zip(split_partition_idx[src], partition_split_indices):
split_indices[i] = np.array(partition)[split_idx].tolist()

return split_indices
6 changes: 3 additions & 3 deletions src/schnetpack/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ def prepare_data(self):
QM9.A: "GHz",
QM9.B: "GHz",
QM9.C: "GHz",
QM9.mu: "D",
QM9.alpha: "a0^3",
QM9.mu: "Debye",
QM9.alpha: "a0 a0 a0",
QM9.homo: "Ha",
QM9.lumo: "Ha",
QM9.gap: "Ha",
QM9.r2: "a0^2",
QM9.r2: "a0 a0",
QM9.zpve: "Ha",
QM9.U0: "Ha",
QM9.U: "Ha",
Expand Down
2 changes: 2 additions & 0 deletions src/schnetpack/representation/so3net.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]):
x = x + dx

inputs["scalar_representation"] = x[:, 0]
# extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1)
inputs["multipole_representation"] = x
return inputs