Skip to content

Added neighborlist from MatScipy package #421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def read(fname):
"rich",
"fasteners",
"dirsync",
"matscipy @ git+https://github.com/libAtoms/matscipy.git",
],
include_package_data=True,
extras_require={"test": ["pytest", "pytest-datadir", "pytest-benchmark"]},
Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/configs/experiment/md17.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ data:
- _target_: schnetpack.transform.RemoveOffsets
property: energy
remove_mean: True
- _target_: schnetpack.transform.ASENeighborList
- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${globals.cutoff}
- _target_: schnetpack.transform.CastTo32

Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/configs/experiment/qm9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ data:
property: ${globals.property}
remove_atomrefs: True
remove_mean: True
- _target_: schnetpack.transform.ASENeighborList
- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${globals.cutoff}
- _target_: schnetpack.transform.CastTo32

Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/configs/predict.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ data:
datapath: ${datapath}
transforms:
- _target_: schnetpack.transform.SubtractCenterOfMass
- _target_: schnetpack.transform.ASENeighborList
- _target_: schnetpack.transform.MatScipyNeighborList
cutoff: ${cutoff}
- _target_: schnetpack.transform.CastTo32

Expand Down
2 changes: 1 addition & 1 deletion src/schnetpack/md/md_configs/calculator/lj.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ stress_key: stress
healing_length: 4.0 #0.3405

defaults:
- neighbor_list: ase
- neighbor_list: matscipy
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_target_: schnetpack.md.neighborlist_md.NeighborListMD
cutoff: ???
cutoff_shell: 2.0
requires_triples: false
base_nbl: schnetpack.transform.MatScipyNeighborList
collate_fn: schnetpack.data.loader._atoms_collate_fn
2 changes: 1 addition & 1 deletion src/schnetpack/md/md_configs/calculator/spk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ stress_key: null
script_model: false

defaults:
- neighbor_list: ase
- neighbor_list: matscipy
2 changes: 1 addition & 1 deletion src/schnetpack/md/md_configs/calculator/spk_ensemble.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ stress_key: null
script_model: false

defaults:
- neighbor_list: ase
- neighbor_list: matscipy
37 changes: 33 additions & 4 deletions src/schnetpack/transform/neighborlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import torch
import shutil
from ase import Atoms
from ase.neighborlist import neighbor_list
from ase.neighborlist import neighbor_list as ase_neighbor_list
from matscipy.neighbours import neighbour_list as msp_neighbor_list
from .base import Transform
from dirsync import sync
import numpy as np
from typing import Optional, Dict, List, Type, Any, Union


__all__ = [
"ASENeighborList",
"MatScipyNeighborList",
"TorchNeighborList",
"CountNeighbors",
"CollectAtomTriples",
Expand Down Expand Up @@ -55,7 +56,6 @@ def forward(
self,
inputs: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:

inputs = self.neighbor_list(inputs)
for postprocess in self.nbh_postprocessing:
inputs = postprocess(inputs)
Expand Down Expand Up @@ -229,14 +229,43 @@ class ASENeighborList(NeighborListTransform):
def _build_neighbor_list(self, Z, positions, cell, pbc, cutoff):
at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)

idx_i, idx_j, S = neighbor_list("ijS", at, cutoff, self_interaction=False)
idx_i, idx_j, S = ase_neighbor_list("ijS", at, cutoff, self_interaction=False)
idx_i = torch.from_numpy(idx_i)
idx_j = torch.from_numpy(idx_j)
S = torch.from_numpy(S).to(dtype=positions.dtype)
offset = torch.mm(S, cell)
return idx_i, idx_j, offset


class MatScipyNeighborList(NeighborListTransform):
"""
Neighborlist using the efficient implementation of the Matscipy package
(https://github.com/libAtoms/matscipy).
"""

def _build_neighbor_list(
self, Z, positions, cell, pbc, cutoff, eps=1e-6, buffer=1.0
):
at = Atoms(numbers=Z, positions=positions, cell=cell, pbc=pbc)

# Add cell if none is present (volume = 0)
if at.cell.volume < eps:
# max values - min values along xyz augmented by small buffer for stability
new_cell = np.ptp(at.positions, axis=0) + buffer
# Set cell and center
at.set_cell(new_cell, scale_atoms=False)
at.center()

# Compute neighborhood
idx_i, idx_j, S = msp_neighbor_list("ijS", at, cutoff)
idx_i = torch.from_numpy(idx_i).long()
idx_j = torch.from_numpy(idx_j).long()
S = torch.from_numpy(S).to(dtype=positions.dtype)
offset = torch.mm(S, cell)

return idx_i, idx_j, offset


class SkinNeighborList(Transform):
"""
Neighbor list provider utilizing a cutoff skin for computational efficiency. Wrapper around neighbor list classes
Expand Down