Add SO3 ops and layers #445

Merged 8 commits on Oct 17, 2022

16 changes: 16 additions & 0 deletions docs/api/nn.rst
@@ -16,13 +16,29 @@ Basic layers

Equivariant layers
------------------

Cartesian:

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

GatedEquivariantBlock

Irreps:

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

RealSphericalHarmonics
SO3TensorProduct
SO3Convolution
SO3GatedNonlinearity
SO3ParametricGatedNonlinearity


Radial basis
------------
1 change: 1 addition & 0 deletions setup.py
@@ -25,6 +25,7 @@ def read(fname):
python_requires=">=3.6",
install_requires=[
"numpy",
"sympy",
"ase>=3.21",
"h5py",
"pyyaml",
33 changes: 18 additions & 15 deletions src/schnetpack/model/base.py
@@ -16,13 +16,15 @@ class AtomisticModel(nn.Module):
"""
Base class for all SchNetPack models.

SchNetPack models should subclass `AtomisticModel` and implement the forward
method. To use the automatic collection of required derivatives, each submodule
that requires gradients w.r.t. the input should list them as strings in
`submodule.required_derivatives = ["input_key"]`. The model needs to call
`self.collect_derivatives()` at the end of its `__init__`.

To make use of post-processing transforms, the model should call
`input = self.postprocess(input)` at the end of its `forward`. The post-processors
will only be applied if `do_postprocessing=True`.

Example:
class SimpleModel(AtomisticModel):
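    # (hypothetical completion; the diff viewer folds the body of the
    # original example, so this sketch only follows the pattern above)
    def __init__(self, representation, output):
        super().__init__()
        self.representation = representation
        self.output = output
        self.collect_derivatives()

    def forward(self, inputs):
        inputs = self.representation(inputs)
        inputs = self.output(inputs)
        inputs = self.postprocess(inputs)
        return inputs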
@@ -126,11 +128,11 @@ def extract_outputs(

class NeuralNetworkPotential(AtomisticModel):
"""
A generic neural network potential class that sequentially applies a list of input
modules, a representation module and a list of output modules.

This can be flexibly configured for various tasks, e.g. property prediction or
potential energy surfaces with response properties.
"""

def __init__(
@@ -145,11 +147,12 @@ def __init__(
"""
Args:
representation: The module that builds representation from inputs.
input_modules: Modules that are applied before representation, e.g. to
    modify input or add additional tensors for response properties.
output_modules: Modules that predict output properties from the
    representation.
postprocessors: Post-processing transforms that may be initialized using the
    `datamodule`, but are not applied during training.
input_dtype_str: The dtype of real inputs.
do_postprocessing: If true, post-processing is activated.
"""
1 change: 1 addition & 0 deletions src/schnetpack/nn/__init__.py
@@ -8,6 +8,7 @@
from schnetpack.nn.blocks import *
from schnetpack.nn.cutoff import *
from schnetpack.nn.equivariant import *
from schnetpack.nn.so3 import *
from schnetpack.nn.scatter import *
from schnetpack.nn.radial import *
from schnetpack.nn.utils import *
Empty file.
10 changes: 10 additions & 0 deletions src/schnetpack/nn/ops/math.py
@@ -0,0 +1,10 @@
import torch


def binom(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
"""
Compute binomial coefficients (n choose k)
"""
return torch.exp(
torch.lgamma(n + 1) - torch.lgamma((n - k) + 1) - torch.lgamma(k + 1)
)
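
# Example (illustration, not part of the diff): the log-gamma form reproduces
# integer binomial coefficients up to floating-point rounding while staying
# overflow-safe and differentiable for large or non-integer arguments:
#
#   binom(torch.tensor(5.0), torch.tensor(2.0))    # -> tensor(10.)  (approx.)
#   binom(torch.tensor(10.0), torch.tensor(3.0))   # -> tensor(120.) (approx.)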
140 changes: 140 additions & 0 deletions src/schnetpack/nn/ops/so3.py
@@ -0,0 +1,140 @@
import math
import torch
from sympy.physics.wigner import clebsch_gordan

from functools import lru_cache
from typing import Tuple


@lru_cache(maxsize=10)
def sh_indices(lmax: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Build index arrays for spherical harmonics

Args:
lmax: maximum angular momentum
"""
ls = torch.arange(0, lmax + 1)
nls = 2 * ls + 1
lidx = torch.repeat_interleave(ls, nls)
midx = torch.cat([torch.arange(-l, l + 1) for l in ls])
return lidx, midx
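
# Example (illustration, not part of the diff): for lmax=2 the nine harmonics
# are flattened in order of increasing l, with m running from -l to l:
#
#   lidx, midx = sh_indices(2)
#   lidx   # tensor([0, 1, 1, 1, 2, 2, 2, 2, 2])
#   midx   # tensor([ 0, -1,  0,  1, -2, -1,  0,  1,  2])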


@lru_cache(maxsize=10)
def generate_sh_to_rsh(lmax: int) -> torch.Tensor:
"""
Generate transformation matrix to convert (complex) spherical harmonics to real form

Args:
lmax: maximum angular momentum
"""
lidx, midx = sh_indices(lmax)
l1 = lidx[:, None]
l2 = lidx[None, :]
m1 = midx[:, None]
m2 = midx[None, :]
U = (
1.0 * ((m1 == 0) * (m2 == 0))
+ (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == m2) * (m1 > 0))
+ 1.0 / math.sqrt(2) * ((m1 == -m2) * (m2 < 0))
+ -1.0j * (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == -m2) * (m1 < 0))
+ 1.0j / math.sqrt(2) * ((m1 == m2) * (m1 < 0))
) * (l1 == l2)
return U
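
# Sanity check (illustration, not part of the diff): the complex and real
# spherical harmonics are both orthonormal bases, so U should be unitary
# and block-diagonal in l:
#
#   U = generate_sh_to_rsh(2)
#   assert torch.allclose(U @ U.conj().T, torch.eye(9, dtype=U.dtype), atol=1e-6)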


@lru_cache(maxsize=10)
def generate_clebsch_gordan(lmax: int) -> torch.Tensor:
"""
Generate standard Clebsch-Gordan coefficients for complex spherical harmonics

Args:
lmax: maximum angular momentum
"""
lidx, midx = sh_indices(lmax)
cg = torch.zeros((lidx.shape[0], lidx.shape[0], lidx.shape[0]))
lidx = lidx.numpy()
midx = midx.numpy()
for c1, (l1, m1) in enumerate(zip(lidx, midx)):
for c2, (l2, m2) in enumerate(zip(lidx, midx)):
for c3, (l3, m3) in enumerate(zip(lidx, midx)):
if abs(l1 - l2) <= l3 <= min(l1 + l2, lmax) and m3 in {
m1 + m2,
m1 - m2,
m2 - m1,
-m1 - m2,
}:
coeff = clebsch_gordan(l1, l2, l3, m1, m2, m3)
cg[c1, c2, c3] = float(coeff)
return cg
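
# Example (illustration, not part of the diff): coupling two l=0 harmonics can
# only produce l=0 with coefficient 1; all other entries in that slice vanish
# by the selection rules above:
#
#   cg = generate_clebsch_gordan(1)   # shape [4, 4, 4]
#   assert cg[0, 0, 0] == 1.0
#   assert torch.all(cg[0, 0, 1:] == 0)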


@lru_cache(maxsize=10)
def generate_clebsch_gordan_rsh(
lmax: int, parity_invariance: bool = True
) -> torch.Tensor:
"""
Generate Clebsch-Gordan coefficients for real spherical harmonics

Args:
lmax: maximum angular momentum
parity_invariance: whether to enforce parity invariance, i.e. only allow
    non-zero coefficients if :math:`(-1)^{l_1} (-1)^{l_2} = (-1)^{l_3}`

"""
lidx, _ = sh_indices(lmax)
cg = generate_clebsch_gordan(lmax).to(dtype=torch.complex64)
complex_to_real = generate_sh_to_rsh(lmax) # (real, complex)
cg_rsh = torch.einsum(
"ijk,mi,nj,ok->mno",
cg,
complex_to_real,
complex_to_real,
complex_to_real.conj(),
)

if parity_invariance:
parity = (-1.0) ** lidx
pmask = parity[:, None, None] * parity[None, :, None] == parity[None, None, :]
cg_rsh *= pmask
else:
lsum = lidx[:, None, None] + lidx[None, :, None] - lidx[None, None, :]
cg_rsh *= 1.0j**lsum

# cast to real
cg_rsh = cg_rsh.real.to(torch.float64)
return cg_rsh
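
# Example (illustration, not part of the diff):
#
#   cg = generate_clebsch_gordan_rsh(2)
#   cg.shape   # torch.Size([9, 9, 9]), one axis per flattened (l, m) index
#   cg.dtype   # torch.float64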


def sparsify_clebsch_gordon(
cg: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Convert a Clebsch-Gordan coefficient tensor to sparse format.

Args:
cg: dense tensor of Clebsch-Gordan coefficients
[(lmax_1+1)^2, (lmax_2+1)^2, (lmax_out+1)^2]

Returns:
cg_sparse: vector of non-zero CG coefficients
idx_in_1: indices for first set of irreps
idx_in_2: indices for second set of irreps
idx_out: indices for output set of irreps
"""
idx = torch.nonzero(cg)
idx_in_1, idx_in_2, idx_out = torch.split(idx, 1, dim=1)
idx_in_1, idx_in_2, idx_out = (
idx_in_1[:, 0],
idx_in_2[:, 0],
idx_out[:, 0],
)
cg_sparse = cg[idx_in_1, idx_in_2, idx_out]
return cg_sparse, idx_in_1, idx_in_2, idx_out
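
# Usage sketch (illustration, not part of the diff): the sparse format turns an
# equivariant tensor product into one gather-multiply-scatter pass. The feature
# layout [batch, (lmax+1)^2, features] below is an assumption for illustration:
#
#   cg = generate_clebsch_gordan_rsh(2).to(torch.float32)
#   cg_sparse, idx_1, idx_2, idx_out = sparsify_clebsch_gordon(cg)
#   x1 = torch.randn(4, 9, 8)                # two sets of irrep features
#   x2 = torch.randn(4, 9, 8)
#   terms = cg_sparse[None, :, None] * x1[:, idx_1] * x2[:, idx_2]
#   y = torch.zeros_like(x1)
#   y.index_add_(1, idx_out, terms)          # accumulate into output irreps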


def round_cmp(x: torch.Tensor, decimals: int = 1) -> torch.Tensor:
    """Round a complex tensor by rounding real and imaginary parts separately."""
    return torch.round(x.real, decimals=decimals) + 1j * torch.round(
        x.imag, decimals=decimals
    )