From e008c5e695466d6e3bfb8f0150269f0a01d9910d Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:28:37 +0200 Subject: [PATCH 1/7] Added matscipy package to setup --- setup.py | 1 + .../md/md_configs/calculator/neighbor_list/matscipy.yaml | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml diff --git a/setup.py b/setup.py index 080d42809..6617d1858 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ def read(fname): "rich", "fasteners", "dirsync", + "matscipy @ git+ssh://git@github.com/libAtoms/matscipy.git", ], include_package_data=True, extras_require={"test": ["pytest", "pytest-datadir", "pytest-benchmark"]}, diff --git a/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml new file mode 100644 index 000000000..4416c3e47 --- /dev/null +++ b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml @@ -0,0 +1,6 @@ +_target_: schnetpack.md.neighborlist_md.NeighborListMD +cutoff: ??? +cutoff_shell: 2.0 +requires_triples: false +base_nbl: schnetpack.transform.ASENeighborList +collate_fn: schnetpack.data.loader._atoms_collate_fn From 12b94254fe632cd93684c151c25df01eb8c709cc Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:29:19 +0200 Subject: [PATCH 2/7] Added transform for matscipy neighborlist --- src/schnetpack/transform/neighborlist.py | 41 ++++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/src/schnetpack/transform/neighborlist.py b/src/schnetpack/transform/neighborlist.py index b8a735104..d226d7dbe 100644 --- a/src/schnetpack/transform/neighborlist.py +++ b/src/schnetpack/transform/neighborlist.py @@ -2,15 +2,16 @@ import torch import shutil from ase import Atoms -from ase.neighborlist import neighbor_list +from ase.neighborlist import neighbor_list as ase_neighbor_list +from matscipy.neighbours import neighbour_list as msp_neighbor_list from .base import Transform from dirsync import sync import numpy as np from typing import Optional, Dict, List, Type, Any, Union - __all__ = [ "ASENeighborList", + "MatScipyNeighborList", "TorchNeighborList", "CountNeighbors", "CollectAtomTriples", @@ -55,7 +56,6 @@ def forward( self, inputs: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - inputs = self.neighbor_list(inputs) for postprocess in self.nbh_postprocessing: inputs = postprocess(inputs) @@ -204,8 +204,8 @@ def forward( pbc = inputs[properties.pbc] idx_i, idx_j, offset = self._build_neighbor_list(Z, R, cell, pbc, self._cutoff) - inputs[properties.idx_i] = idx_i.detach() - inputs[properties.idx_j] = idx_j.detach() + inputs[properties.idx_i] = idx_i.detach().long() + inputs[properties.idx_j] = idx_j.detach().long() inputs[properties.offsets] = offset return inputs @@ -229,11 +229,40 @@ class ASENeighborList(NeighborListTransform): def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc) - idx_i, idx_j, S = neighbor_list("ijS", at, cutoff, self_interaction=False) + idx_i, idx_j, S = ase_neighbor_list("ijS", at, cutoff, self_interaction=False) + idx_i = torch.from_numpy(idx_i) + idx_j = torch.from_numpy(idx_j) + S = torch.from_numpy(S).to(dtype=positions.dtype) + offset = torch.mm(S, cell) + return idx_i, idx_j, offset + + +class MatScipyNeighborList(NeighborListTransform): + """ + Neighborlist using the efficient implementation of the Matscipy package + (https://github.com/libAtoms/matscipy). + """ + + def _build_neighbor_list( + self, Z, positions, cell, pbc, cutoff, eps=1e-6, buffer=1.0 + ): + at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc) + + # Add cell if none is present (volume = 0) + if at.cell.volume < eps: + # max values - min values along xyz augmented by small buffer for stability + new_cell = np.ptp(at.positions, axis=0) + buffer + # Set cell and center + at.set_cell(new_cell, scale_atoms=False) + at.center() + + # Compute neighborhood + idx_i, idx_j, S = msp_neighbor_list("ijS", at, cutoff) idx_i = torch.from_numpy(idx_i) idx_j = torch.from_numpy(idx_j) S = torch.from_numpy(S).to(dtype=positions.dtype) offset = torch.mm(S, cell) + return idx_i, idx_j, offset From c30b481a51c4b69677bfd7d0f51f1fb08fcc0f5d Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:30:36 +0200 Subject: [PATCH 3/7] Changed neighborlist default to MatScipy --- src/schnetpack/configs/experiment/md17.yaml | 2 +- src/schnetpack/configs/experiment/qm9.yaml | 2 +- src/schnetpack/configs/predict.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/schnetpack/configs/experiment/md17.yaml b/src/schnetpack/configs/experiment/md17.yaml index 5154e8fed..48ecdf469 100644 --- a/src/schnetpack/configs/experiment/md17.yaml +++ b/src/schnetpack/configs/experiment/md17.yaml @@ -22,7 +22,7 @@ data: - _target_: schnetpack.transform.RemoveOffsets property: energy remove_mean: True - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${globals.cutoff} - _target_: schnetpack.transform.CastTo32 diff --git a/src/schnetpack/configs/experiment/qm9.yaml b/src/schnetpack/configs/experiment/qm9.yaml index c4f6b67b5..7f556973b 100644 --- a/src/schnetpack/configs/experiment/qm9.yaml +++ b/src/schnetpack/configs/experiment/qm9.yaml @@ -18,7 +18,7 @@ data: property: ${globals.property} remove_atomrefs: True remove_mean: True - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${globals.cutoff} - _target_: schnetpack.transform.CastTo32 diff --git a/src/schnetpack/configs/predict.yaml b/src/schnetpack/configs/predict.yaml index 7fdb53f9c..4e015edcc 100644 --- a/src/schnetpack/configs/predict.yaml +++ b/src/schnetpack/configs/predict.yaml @@ -15,7 +15,7 @@ data: datapath: ${datapath} transforms: - _target_: schnetpack.transform.SubtractCenterOfMass - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${cutoff} - _target_: schnetpack.transform.CastTo32 From 646eb251773ee4a5911d36cc880ad7c99ebb8cca Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:31:13 +0200 Subject: [PATCH 4/7] Added config for matscipy MD neighborlist --- .../md/md_configs/calculator/neighbor_list/matscipy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml index 4416c3e47..93e582911 100644 --- a/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml +++ b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml @@ -2,5 +2,5 @@ _target_: schnetpack.md.neighborlist_md.NeighborListMD cutoff: ??? cutoff_shell: 2.0 requires_triples: false -base_nbl: schnetpack.transform.ASENeighborList +base_nbl: schnetpack.transform.MatScipyNeighborList collate_fn: schnetpack.data.loader._atoms_collate_fn From 2066937c19d8d127cd82a971c808b8062f7ce929 Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:31:37 +0200 Subject: [PATCH 5/7] Changed MD neighborlist defaults to matscipy --- src/schnetpack/md/md_configs/calculator/lj.yaml | 2 +- src/schnetpack/md/md_configs/calculator/spk.yaml | 2 +- src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/schnetpack/md/md_configs/calculator/lj.yaml b/src/schnetpack/md/md_configs/calculator/lj.yaml index 472a9620a..743033159 100644 --- a/src/schnetpack/md/md_configs/calculator/lj.yaml +++ b/src/schnetpack/md/md_configs/calculator/lj.yaml @@ -9,4 +9,4 @@ stress_key: stress healing_length: 4.0 #0.3405 defaults: - - neighbor_list: ase \ No newline at end of file + - neighbor_list: matscipy \ No newline at end of file diff --git a/src/schnetpack/md/md_configs/calculator/spk.yaml b/src/schnetpack/md/md_configs/calculator/spk.yaml index 3299dae6b..78864b231 100644 --- a/src/schnetpack/md/md_configs/calculator/spk.yaml +++ b/src/schnetpack/md/md_configs/calculator/spk.yaml @@ -11,4 +11,4 @@ stress_key: null script_model: false defaults: - - neighbor_list: ase \ No newline at end of file + - neighbor_list: matscipy \ No newline at end of file diff --git a/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml b/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml index 642a0957e..3d4815736 100644 --- a/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml +++ b/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml @@ -12,4 +12,4 @@ stress_key: null script_model: false defaults: - - neighbor_list: ase + - neighbor_list: matscipy From e98cae87c8ec73b72e9f0d6f904898b8b49b96b3 Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 16:52:10 +0200 Subject: [PATCH 6/7] Moved long conversion to MatScipyNeighborlist --- src/schnetpack/transform/neighborlist.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/schnetpack/transform/neighborlist.py b/src/schnetpack/transform/neighborlist.py index d226d7dbe..ddf5a1151 100644 --- a/src/schnetpack/transform/neighborlist.py +++ b/src/schnetpack/transform/neighborlist.py @@ -204,8 +204,8 @@ def forward( pbc = inputs[properties.pbc] idx_i, idx_j, offset = self._build_neighbor_list(Z, R, cell, pbc, self._cutoff) - inputs[properties.idx_i] = idx_i.detach().long() - inputs[properties.idx_j] = idx_j.detach().long() + inputs[properties.idx_i] = idx_i.detach() + inputs[properties.idx_j] = idx_j.detach() inputs[properties.offsets] = offset return inputs @@ -258,8 +258,8 @@ def _build_neighbor_list( # Compute neighborhood idx_i, idx_j, S = msp_neighbor_list("ijS", at, cutoff) - idx_i = torch.from_numpy(idx_i) - idx_j = torch.from_numpy(idx_j) + idx_i = torch.from_numpy(idx_i).long() + idx_j = torch.from_numpy(idx_j).long() S = torch.from_numpy(S).to(dtype=positions.dtype) offset = torch.mm(S, cell) From 72a19ebd9349770dca8da54a4a37ed271eb9d1b0 Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 21 Jul 2022 18:11:36 +0200 Subject: [PATCH 7/7] Changed matscipy git to https --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6617d1858..440e938c0 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def read(fname): "rich", "fasteners", "dirsync", - "matscipy @ git+ssh://git@github.com/libAtoms/matscipy.git", + "matscipy @ git+https://github.com/libAtoms/matscipy.git", ], include_package_data=True, extras_require={"test": ["pytest", "pytest-datadir", "pytest-benchmark"]},