diff --git a/setup.py b/setup.py index 080d42809..440e938c0 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ def read(fname): "rich", "fasteners", "dirsync", + "matscipy @ git+https://github.com/libAtoms/matscipy.git", ], include_package_data=True, extras_require={"test": ["pytest", "pytest-datadir", "pytest-benchmark"]}, diff --git a/src/schnetpack/configs/experiment/md17.yaml b/src/schnetpack/configs/experiment/md17.yaml index 5154e8fed..48ecdf469 100644 --- a/src/schnetpack/configs/experiment/md17.yaml +++ b/src/schnetpack/configs/experiment/md17.yaml @@ -22,7 +22,7 @@ data: - _target_: schnetpack.transform.RemoveOffsets property: energy remove_mean: True - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${globals.cutoff} - _target_: schnetpack.transform.CastTo32 diff --git a/src/schnetpack/configs/experiment/qm9.yaml b/src/schnetpack/configs/experiment/qm9.yaml index c4f6b67b5..7f556973b 100644 --- a/src/schnetpack/configs/experiment/qm9.yaml +++ b/src/schnetpack/configs/experiment/qm9.yaml @@ -18,7 +18,7 @@ data: property: ${globals.property} remove_atomrefs: True remove_mean: True - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${globals.cutoff} - _target_: schnetpack.transform.CastTo32 diff --git a/src/schnetpack/configs/predict.yaml b/src/schnetpack/configs/predict.yaml index 7fdb53f9c..4e015edcc 100644 --- a/src/schnetpack/configs/predict.yaml +++ b/src/schnetpack/configs/predict.yaml @@ -15,7 +15,7 @@ data: datapath: ${datapath} transforms: - _target_: schnetpack.transform.SubtractCenterOfMass - - _target_: schnetpack.transform.ASENeighborList + - _target_: schnetpack.transform.MatScipyNeighborList cutoff: ${cutoff} - _target_: schnetpack.transform.CastTo32 diff --git a/src/schnetpack/md/md_configs/calculator/lj.yaml b/src/schnetpack/md/md_configs/calculator/lj.yaml index 472a9620a..743033159 100644 --- a/src/schnetpack/md/md_configs/calculator/lj.yaml +++ b/src/schnetpack/md/md_configs/calculator/lj.yaml @@ -9,4 +9,4 @@ stress_key: stress healing_length: 4.0 #0.3405 defaults: - - neighbor_list: ase \ No newline at end of file + - neighbor_list: matscipy \ No newline at end of file diff --git a/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml new file mode 100644 index 000000000..93e582911 --- /dev/null +++ b/src/schnetpack/md/md_configs/calculator/neighbor_list/matscipy.yaml @@ -0,0 +1,6 @@ +_target_: schnetpack.md.neighborlist_md.NeighborListMD +cutoff: ??? +cutoff_shell: 2.0 +requires_triples: false +base_nbl: schnetpack.transform.MatScipyNeighborList +collate_fn: schnetpack.data.loader._atoms_collate_fn diff --git a/src/schnetpack/md/md_configs/calculator/spk.yaml b/src/schnetpack/md/md_configs/calculator/spk.yaml index 3299dae6b..78864b231 100644 --- a/src/schnetpack/md/md_configs/calculator/spk.yaml +++ b/src/schnetpack/md/md_configs/calculator/spk.yaml @@ -11,4 +11,4 @@ stress_key: null script_model: false defaults: - - neighbor_list: ase \ No newline at end of file + - neighbor_list: matscipy \ No newline at end of file diff --git a/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml b/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml index 642a0957e..3d4815736 100644 --- a/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml +++ b/src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml @@ -12,4 +12,4 @@ stress_key: null script_model: false defaults: - - neighbor_list: ase + - neighbor_list: matscipy diff --git a/src/schnetpack/transform/neighborlist.py b/src/schnetpack/transform/neighborlist.py index b8a735104..ddf5a1151 100644 --- a/src/schnetpack/transform/neighborlist.py +++ b/src/schnetpack/transform/neighborlist.py @@ -2,15 +2,16 @@ import torch import shutil from ase import Atoms -from ase.neighborlist import neighbor_list +from ase.neighborlist import neighbor_list as ase_neighbor_list +from matscipy.neighbours import neighbour_list as msp_neighbor_list from .base import Transform from dirsync import sync import numpy as np from typing import Optional, Dict, List, Type, Any, Union - __all__ = [ "ASENeighborList", + "MatScipyNeighborList", "TorchNeighborList", "CountNeighbors", "CollectAtomTriples", @@ -55,7 +56,6 @@ def forward( self, inputs: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: - inputs = self.neighbor_list(inputs) for postprocess in self.nbh_postprocessing: inputs = postprocess(inputs) @@ -229,7 +229,7 @@ class ASENeighborList(NeighborListTransform): def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc) - idx_i, idx_j, S = neighbor_list("ijS", at, cutoff, self_interaction=False) + idx_i, idx_j, S = ase_neighbor_list("ijS", at, cutoff, self_interaction=False) idx_i = torch.from_numpy(idx_i) idx_j = torch.from_numpy(idx_j) S = torch.from_numpy(S).to(dtype=positions.dtype) @@ -237,6 +237,35 @@ def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff): return idx_i, idx_j, offset +class MatScipyNeighborList(NeighborListTransform): + """ + Neighborlist using the efficient implementation of the Matscipy package + (https://github.com/libAtoms/matscipy). + """ + + def _build_neighbor_list( + self, Z, positions, cell, pbc, cutoff, eps=1e-6, buffer=1.0 + ): + at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc) + + # Add cell if none is present (volume = 0) + if at.cell.volume < eps: + # max values - min values along xyz augmented by small buffer for stability + new_cell = np.ptp(at.positions, axis=0) + buffer + # Set cell and center + at.set_cell(new_cell, scale_atoms=False) + at.center() + + # Compute neighborhood + idx_i, idx_j, S = msp_neighbor_list("ijS", at, cutoff) + idx_i = torch.from_numpy(idx_i).long() + idx_j = torch.from_numpy(idx_j).long() + S = torch.from_numpy(S).to(dtype=positions.dtype) + offset = torch.mm(S, cell) + + return idx_i, idx_j, offset + + class SkinNeighborList(Transform): """ Neighbor list provider utilizing a cutoff skin for computational efficiency. Wrapper around neighbor list classes