Skip to content

Commit 7edc56d

Browse files
authored
Added neighborlist from MatScipy package (#421)
* Added matscipy package to setup * Added transform for matscipy neighborlist * Changed neighborlist default to MatScipy * Added config for matscipy MD neighborlist * Changed MD neighborlist defaults to matscipy * Moved long conversion to MatScipyNeighborlist * Changed matscipy git to https
1 parent 6fe78ef commit 7edc56d

File tree

9 files changed

+46
-10
lines changed

9 files changed

+46
-10
lines changed

Diff for: setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def read(fname):
3535
"rich",
3636
"fasteners",
3737
"dirsync",
38+
"matscipy @ git+https://github.com/libAtoms/matscipy.git",
3839
],
3940
include_package_data=True,
4041
extras_require={"test": ["pytest", "pytest-datadir", "pytest-benchmark"]},

Diff for: src/schnetpack/configs/experiment/md17.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ data:
2222
- _target_: schnetpack.transform.RemoveOffsets
2323
property: energy
2424
remove_mean: True
25-
- _target_: schnetpack.transform.ASENeighborList
25+
- _target_: schnetpack.transform.MatScipyNeighborList
2626
cutoff: ${globals.cutoff}
2727
- _target_: schnetpack.transform.CastTo32
2828

Diff for: src/schnetpack/configs/experiment/qm9.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ data:
1818
property: ${globals.property}
1919
remove_atomrefs: True
2020
remove_mean: True
21-
- _target_: schnetpack.transform.ASENeighborList
21+
- _target_: schnetpack.transform.MatScipyNeighborList
2222
cutoff: ${globals.cutoff}
2323
- _target_: schnetpack.transform.CastTo32
2424

Diff for: src/schnetpack/configs/predict.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ data:
1515
datapath: ${datapath}
1616
transforms:
1717
- _target_: schnetpack.transform.SubtractCenterOfMass
18-
- _target_: schnetpack.transform.ASENeighborList
18+
- _target_: schnetpack.transform.MatScipyNeighborList
1919
cutoff: ${cutoff}
2020
- _target_: schnetpack.transform.CastTo32
2121

Diff for: src/schnetpack/md/md_configs/calculator/lj.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ stress_key: stress
99
healing_length: 4.0 #0.3405
1010

1111
defaults:
12-
- neighbor_list: ase
12+
- neighbor_list: matscipy
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
_target_: schnetpack.md.neighborlist_md.NeighborListMD
2+
cutoff: ???
3+
cutoff_shell: 2.0
4+
requires_triples: false
5+
base_nbl: schnetpack.transform.MatScipyNeighborList
6+
collate_fn: schnetpack.data.loader._atoms_collate_fn

Diff for: src/schnetpack/md/md_configs/calculator/spk.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ stress_key: null
1111
script_model: false
1212

1313
defaults:
14-
- neighbor_list: ase
14+
- neighbor_list: matscipy

Diff for: src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ stress_key: null
1212
script_model: false
1313

1414
defaults:
15-
- neighbor_list: ase
15+
- neighbor_list: matscipy

Diff for: src/schnetpack/transform/neighborlist.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
import torch
33
import shutil
44
from ase import Atoms
5-
from ase.neighborlist import neighbor_list
5+
from ase.neighborlist import neighbor_list as ase_neighbor_list
6+
from matscipy.neighbours import neighbour_list as msp_neighbor_list
67
from .base import Transform
78
from dirsync import sync
89
import numpy as np
910
from typing import Optional, Dict, List, Type, Any, Union
1011

11-
1212
__all__ = [
1313
"ASENeighborList",
14+
"MatScipyNeighborList",
1415
"TorchNeighborList",
1516
"CountNeighbors",
1617
"CollectAtomTriples",
@@ -55,7 +56,6 @@ def forward(
5556
self,
5657
inputs: Dict[str, torch.Tensor],
5758
) -> Dict[str, torch.Tensor]:
58-
5959
inputs = self.neighbor_list(inputs)
6060
for postprocess in self.nbh_postprocessing:
6161
inputs = postprocess(inputs)
@@ -229,14 +229,43 @@ class ASENeighborList(NeighborListTransform):
229229
def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
230230
at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)
231231

232-
idx_i, idx_j, S = neighbor_list("ijS", at, cutoff, self_interaction=False)
232+
idx_i, idx_j, S = ase_neighbor_list("ijS", at, cutoff, self_interaction=False)
233233
idx_i = torch.from_numpy(idx_i)
234234
idx_j = torch.from_numpy(idx_j)
235235
S = torch.from_numpy(S).to(dtype=positions.dtype)
236236
offset = torch.mm(S, cell)
237237
return idx_i, idx_j, offset
238238

239239

240+
class MatScipyNeighborList(NeighborListTransform):
241+
"""
242+
Neighborlist using the efficient implementation of the Matscipy package
243+
(https://github.com/libAtoms/matscipy).
244+
"""
245+
246+
def _build_neighbor_list(
247+
self, Z, positions, cell, pbc, cutoff, eps=1e-6, buffer=1.0
248+
):
249+
at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)
250+
251+
# Add cell if none is present (volume = 0)
252+
if at.cell.volume < eps:
253+
# max values - min values along xyz augmented by small buffer for stability
254+
new_cell = np.ptp(at.positions, axis=0) + buffer
255+
# Set cell and center
256+
at.set_cell(new_cell, scale_atoms=False)
257+
at.center()
258+
259+
# Compute neighborhood
260+
idx_i, idx_j, S = msp_neighbor_list("ijS", at, cutoff)
261+
idx_i = torch.from_numpy(idx_i).long()
262+
idx_j = torch.from_numpy(idx_j).long()
263+
S = torch.from_numpy(S).to(dtype=positions.dtype)
264+
offset = torch.mm(S, cell)
265+
266+
return idx_i, idx_j, offset
267+
268+
240269
class SkinNeighborList(Transform):
241270
"""
242271
Neighbor list provider utilizing a cutoff skin for computational efficiency. Wrapper around neighbor list classes

0 commit comments

Comments
 (0)