Skip to content

Commit 2e7332e

Browse files
authored
Merge pull request #459 from atomistic-machine-learning/kts/multipoles
Fix configs and docstrings
2 parents f81f641 + fa6f10a commit 2e7332e

File tree

17 files changed

+78
-30
lines changed

17 files changed

+78
-30
lines changed

Diff for: src/schnetpack/atomistic/atomwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(
127127
return_charges: If true, return latent partial charges
128128
dipole_key: the key under which the dipoles will be stored
129129
charges_key: the key under which partial charges will be stored
130-
correct_charges: If true, forces the sum of partial charges to be the the total charge, if provided,
130+
correct_charges: If true, forces the sum of partial charges to be the total charge, if provided,
131131
and zero otherwise.
132132
use_vector_representation: If true, use vector representation to predict local,
133133
atomic dipoles.

Diff for: src/schnetpack/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def train(config: DictConfig):
4040
4141
"""
4242
print(header)
43-
log.info("Runnning on host: " + str(socket.gethostname()))
43+
log.info("Running on host: " + str(socket.gethostname()))
4444

4545
if OmegaConf.is_missing(config, "run.data_dir"):
4646
log.error(

Diff for: src/schnetpack/configs/callbacks/earlystopping.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
early_stopping:
22
_target_: pytorch_lightning.callbacks.EarlyStopping
33
monitor: "val_loss" # name of the logged metric which determines when model is improving
4-
patience: 100 # how many epochs of not improving until training stops
4+
patience: 150 # how many epochs of not improving until training stops
55
mode: "min" # can be "max" or "min"
6-
min_delta: 1e-5 # minimum change in the monitored metric needed to qualify as an improvement
6+
min_delta: 0.0 # minimum change in the monitored metric needed to qualify as an improvement

Diff for: src/schnetpack/configs/callbacks/ema.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
ema:
22
_target_: schnetpack.train.ExponentialMovingAverage
3-
decay: 0.995
3+
decay: 0.9

Diff for: src/schnetpack/configs/data/qm9.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ _target_: schnetpack.datasets.QM9
55

66
datapath: ${run.data_dir}/qm9.db # data_dir is specified in train.yaml
77
batch_size: 100
8-
num_train: 110000
9-
num_val: 10000
8+
num_train: 105000
9+
num_val: 5000
1010
remove_uncharacterized: False
1111

1212
# convert to typically used units

Diff for: src/schnetpack/configs/experiment/qm9_energy.yaml renamed to src/schnetpack/configs/experiment/qm9_atomwise.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ globals:
1111
cutoff: 5.
1212
lr: 5e-4
1313
property: energy_U0
14+
aggregation: sum
1415

1516
data:
1617
transforms:
@@ -28,7 +29,7 @@ model:
2829
- _target_: schnetpack.atomistic.Atomwise
2930
output_key: ${globals.property}
3031
n_in: ${model.representation.n_atom_basis}
31-
aggregation_mode: sum
32+
aggregation_mode: ${globals.aggregation}
3233
postprocessors:
3334
- _target_: schnetpack.transform.CastTo64
3435
- _target_: schnetpack.transform.AddOffsets

Diff for: src/schnetpack/configs/experiment/qm9_dipole.yaml

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# @package _global_
2+
3+
defaults:
4+
- override /model: nnp
5+
- override /data: qm9
6+
7+
run:
8+
experiment: qm9_${globals.property}
9+
10+
globals:
11+
cutoff: 5.
12+
lr: 5e-4
13+
property: dipole_moment
14+
15+
data:
16+
transforms:
17+
- _target_: schnetpack.transform.SubtractCenterOfMass
18+
- _target_: schnetpack.transform.MatScipyNeighborList
19+
cutoff: ${globals.cutoff}
20+
- _target_: schnetpack.transform.CastTo32
21+
22+
model:
23+
output_modules:
24+
- _target_: schnetpack.atomistic.DipoleMoment
25+
dipole_key: ${globals.property}
26+
n_in: ${model.representation.n_atom_basis}
27+
predict_magnitude: True
28+
use_vector_representation: False
29+
postprocessors:
30+
- _target_: schnetpack.transform.CastTo64
31+
32+
task:
33+
outputs:
34+
- _target_: schnetpack.task.ModelOutput
35+
name: ${globals.property}
36+
loss_fn:
37+
_target_: torch.nn.MSELoss
38+
metrics:
39+
mae:
40+
_target_: torchmetrics.regression.MeanAbsoluteError
41+
rmse:
42+
_target_: torchmetrics.regression.MeanSquaredError
43+
squared: False
44+
loss_weight: 1.
Diff for: src/schnetpack/configs/model/representation/fieldschnet.yaml (filename header lost in capture; inferred from `_target_: schnetpack.representation.FieldSchNet`)

Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
defaults:
2+
- radial_basis: gaussian
3+
14
_target_: schnetpack.representation.FieldSchNet
25
n_atom_basis: 128
36
n_interactions: 5
47
external_fields: []
58
response_properties: ${globals.response_properties}
69
shared_interactions: False
7-
radial_basis:
8-
_target_: schnetpack.nn.radial.GaussianRBF
9-
n_rbf: 20
10-
cutoff: ${globals.cutoff}
1110
cutoff_fn:
1211
_target_: schnetpack.nn.cutoff.CosineCutoff
1312
cutoff: ${globals.cutoff}
Diff for: src/schnetpack/configs/model/representation/painn.yaml (filename header lost in capture; inferred from `_target_: schnetpack.representation.PaiNN`)

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
defaults:
2+
- radial_basis: gaussian
3+
14
_target_: schnetpack.representation.PaiNN
25
n_atom_basis: 128
36
n_interactions: 3
47
shared_interactions: False
58
shared_filters: False
6-
radial_basis:
7-
_target_: schnetpack.nn.radial.GaussianRBF
8-
n_rbf: 20
9-
cutoff: ${globals.cutoff}
109
cutoff_fn:
1110
_target_: schnetpack.nn.cutoff.CosineCutoff
1211
cutoff: ${globals.cutoff}
Diff for: src/schnetpack/configs/model/representation/radial_basis/bessel.yaml (new file; filename header lost in capture, inferred from `_target_: schnetpack.nn.radial.BesselRBF`)

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_target_: schnetpack.nn.radial.BesselRBF
2+
n_rbf: 20
3+
cutoff: ${globals.cutoff}
Diff for: src/schnetpack/configs/model/representation/radial_basis/gaussian.yaml (new file; filename header lost in capture, inferred from `_target_: schnetpack.nn.radial.GaussianRBF`)

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_target_: schnetpack.nn.radial.GaussianRBF
2+
n_rbf: 20
3+
cutoff: ${globals.cutoff}
Diff for: src/schnetpack/configs/model/representation/schnet.yaml (filename header lost in capture; inferred from `_target_: schnetpack.representation.SchNet`)

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1+
defaults:
2+
- radial_basis: gaussian
3+
14
_target_: schnetpack.representation.SchNet
25
n_atom_basis: 128
36
n_interactions: 6
4-
radial_basis:
5-
_target_: schnetpack.nn.radial.GaussianRBF
6-
n_rbf: 20
7-
cutoff: ${globals.cutoff}
87
cutoff_fn:
98
_target_: schnetpack.nn.cutoff.CosineCutoff
109
cutoff: ${globals.cutoff}
Diff for: src/schnetpack/configs/model/representation/so3net.yaml (filename header lost in capture; inferred from `_target_: schnetpack.representation.SO3net`)

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
defaults:
2+
- radial_basis: gaussian
3+
14
_target_: schnetpack.representation.SO3net
25
_recursive_: False
36
n_atom_basis: 64
47
n_interactions: 3
58
lmax: 2
69
shared_interactions: False
7-
radial_basis:
8-
_target_: schnetpack.nn.radial.GaussianRBF
9-
n_rbf: 20
10-
cutoff: ${globals.cutoff}
1110
cutoff_fn:
1211
_target_: schnetpack.nn.cutoff.CosineCutoff
1312
cutoff: ${globals.cutoff}

Diff for: src/schnetpack/configs/task/scheduler/reduce_on_plateau.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ scheduler_monitor: val_loss
44
scheduler_args:
55
mode: min
66
factor: 0.8
7-
patience: 50
8-
threshold: 1e-4
7+
patience: 75
8+
threshold: 0.0
99
threshold_mode: rel
1010
cooldown: 10
1111
min_lr: 0.0

Diff for: src/schnetpack/data/splitting.py

-1
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,4 @@ def split(self, dataset, *split_sizes):
166166
)
167167
for i, split_idx in zip(split_partition_idx[src], partition_split_indices):
168168
split_indices[i] = np.array(partition)[split_idx].tolist()
169-
170169
return split_indices

Diff for: src/schnetpack/datasets/qm9.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,12 @@ def prepare_data(self):
133133
QM9.A: "GHz",
134134
QM9.B: "GHz",
135135
QM9.C: "GHz",
136-
QM9.mu: "D",
137-
QM9.alpha: "a0^3",
136+
QM9.mu: "Debye",
137+
QM9.alpha: "a0 a0 a0",
138138
QM9.homo: "Ha",
139139
QM9.lumo: "Ha",
140140
QM9.gap: "Ha",
141-
QM9.r2: "a0^2",
141+
QM9.r2: "a0 a0",
142142
QM9.zpve: "Ha",
143143
QM9.U0: "Ha",
144144
QM9.U: "Ha",

Diff for: src/schnetpack/representation/so3net.py

+2
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]):
118118
x = x + dx
119119

120120
inputs["scalar_representation"] = x[:, 0]
121+
# extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
122+
inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1)
121123
inputs["multipole_representation"] = x
122124
return inputs

0 commit comments

Comments
 (0)