
Commit ed88ace

Merge pull request #445 from atomistic-machine-learning/kts/so3ops
Add SO3 ops and layers
2 parents 617c468 + 7a58503

File tree: 8 files changed, +541 −15 lines

Diff for: docs/api/nn.rst (+16)

@@ -16,13 +16,29 @@ Basic layers

 Equivariant layers
 ------------------
+
+Cartesian:
+
 .. autosummary::
     :toctree: generated
     :nosignatures:
     :template: classtemplate.rst

     GatedEquivariantBlock

+Irreps:
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    RealSphericalHarmonics
+    SO3TensorProduct
+    SO3Convolution
+    SO3GatedNonlinearity
+    SO3ParametricGatedNonlinearity
+

 Radial basis
 ------------

Diff for: setup.py (+1)

@@ -25,6 +25,7 @@ def read(fname):
     python_requires=">=3.6",
     install_requires=[
         "numpy",
+        "sympy",
         "ase>=3.21",
         "h5py",
         "pyyaml",

Diff for: src/schnetpack/model/base.py (+18 −15)

@@ -16,13 +16,15 @@ class AtomisticModel(nn.Module):
     """
     Base class for all SchNetPack models.

-    SchNetPack models should subclass `AtomisticModel` implement the forward method. To use the automatic collection of
-    required derivatives, each submodule that requires gradients w.r.t to the input, should list them as strings in
-    `submodule.required_derivatives = ["input_key"]`. The model needs to call `self.collect_derivatives()` at the end
-    of its `__init__`.
+    SchNetPack models should subclass `AtomisticModel` and implement the forward
+    method. To use the automatic collection of required derivatives, each submodule
+    that requires gradients w.r.t. the input should list them as strings in
+    `submodule.required_derivatives = ["input_key"]`. The model needs to call
+    `self.collect_derivatives()` at the end of its `__init__`.

-    To make use of post-processing transform, the model should call `input = self.postprocess(input)` at the end of
-    its `forward`. The post processors will only be applied if `do_postprocessing=True`.
+    To make use of post-processing transforms, the model should call
+    `input = self.postprocess(input)` at the end of its `forward`. The post processors
+    will only be applied if `do_postprocessing=True`.

     Example:
         class SimpleModel(AtomisticModel):

@@ -126,11 +128,11 @@ def extract_outputs(

 class NeuralNetworkPotential(AtomisticModel):
     """
-    A generic neural network potential class that sequentially applies a list of input modules, a representation module
-    and a list of output modules.
+    A generic neural network potential class that sequentially applies a list of input
+    modules, a representation module, and a list of output modules.

-    This can be flexibly configured for various, e.g. property prediction or potential energy sufaces with response
-    properties.
+    This can be flexibly configured for various tasks, e.g. property prediction or
+    potential energy surfaces with response properties.
     """

     def __init__(

@@ -145,11 +147,12 @@ def __init__(
         """
         Args:
             representation: The module that builds representation from inputs.
-            input_modules: Modules that are applied before representation, e.g. to modify input or add additional tensors for response
-                properties.
-            output_modules: Modules that predict output properties from the representation.
-            postprocessors: Post-processing transforms that may be initialized using te `datamodule`, but are not
-                applied during training.
+            input_modules: Modules that are applied before representation, e.g. to
+                modify input or add additional tensors for response properties.
+            output_modules: Modules that predict output properties from the
+                representation.
+            postprocessors: Post-processing transforms that may be initialized using
+                the `datamodule`, but are not applied during training.
             input_dtype_str: The dtype of real inputs.
             do_postprocessing: If true, post-processing is activated.
         """

Diff for: src/schnetpack/nn/__init__.py (+1)

@@ -8,6 +8,7 @@
 from schnetpack.nn.blocks import *
 from schnetpack.nn.cutoff import *
 from schnetpack.nn.equivariant import *
+from schnetpack.nn.so3 import *
 from schnetpack.nn.scatter import *
 from schnetpack.nn.radial import *
 from schnetpack.nn.utils import *

Diff for: src/schnetpack/nn/ops/__init__.py

Whitespace-only changes.

Diff for: src/schnetpack/nn/ops/math.py (+10)

@@ -0,0 +1,10 @@
+import torch
+
+
+def binom(n: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the binomial coefficient :math:`\binom{n}{k}` via log-gamma functions.
+    """
+    return torch.exp(
+        torch.lgamma(n + 1) - torch.lgamma((n - k) + 1) - torch.lgamma(k + 1)
+    )
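
Quick check (editor's note, not part of the commit): for integer arguments the log-gamma formulation reproduces exact binomial coefficients, while also working elementwise on tensors and for non-integer n and k:

import math
import torch

n = torch.tensor([5.0, 10.0])
k = torch.tensor([2.0, 3.0])
print(binom(n, k))                        # tensor([ 10., 120.]), up to float error
print(math.comb(5, 2), math.comb(10, 3))  # 10 120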

Diff for: src/schnetpack/nn/ops/so3.py (+140)

@@ -0,0 +1,140 @@
+import math
+import torch
+from sympy.physics.wigner import clebsch_gordan
+
+from functools import lru_cache
+from typing import Tuple
+
+
+@lru_cache(maxsize=10)
+def sh_indices(lmax: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Build index arrays for spherical harmonics.
+
+    Args:
+        lmax: maximum angular momentum
+    """
+    ls = torch.arange(0, lmax + 1)
+    nls = 2 * ls + 1
+    lidx = torch.repeat_interleave(ls, nls)
+    midx = torch.cat([torch.arange(-l, l + 1) for l in ls])
+    return lidx, midx
+
+
+@lru_cache(maxsize=10)
+def generate_sh_to_rsh(lmax: int) -> torch.Tensor:
+    """
+    Generate the transformation matrix that converts (complex) spherical harmonics
+    to real form.
+
+    Args:
+        lmax: maximum angular momentum
+    """
+    lidx, midx = sh_indices(lmax)
+    l1 = lidx[:, None]
+    l2 = lidx[None, :]
+    m1 = midx[:, None]
+    m2 = midx[None, :]
+    U = (
+        1.0 * ((m1 == 0) * (m2 == 0))
+        + (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == m2) * (m1 > 0))
+        + 1.0 / math.sqrt(2) * ((m1 == -m2) * (m2 < 0))
+        + -1.0j * (-1.0) ** abs(m1) / math.sqrt(2) * ((m1 == -m2) * (m1 < 0))
+        + 1.0j / math.sqrt(2) * ((m1 == m2) * (m1 < 0))
+    ) * (l1 == l2)
+    return U
+
+
+@lru_cache(maxsize=10)
+def generate_clebsch_gordan(lmax: int) -> torch.Tensor:
+    """
+    Generate standard Clebsch-Gordan coefficients for complex spherical harmonics.
+
+    Args:
+        lmax: maximum angular momentum
+    """
+    lidx, midx = sh_indices(lmax)
+    cg = torch.zeros((lidx.shape[0], lidx.shape[0], lidx.shape[0]))
+    lidx = lidx.numpy()
+    midx = midx.numpy()
+    for c1, (l1, m1) in enumerate(zip(lidx, midx)):
+        for c2, (l2, m2) in enumerate(zip(lidx, midx)):
+            for c3, (l3, m3) in enumerate(zip(lidx, midx)):
+                if abs(l1 - l2) <= l3 <= min(l1 + l2, lmax) and m3 in {
+                    m1 + m2,
+                    m1 - m2,
+                    m2 - m1,
+                    -m1 - m2,
+                }:
+                    coeff = clebsch_gordan(l1, l2, l3, m1, m2, m3)
+                    cg[c1, c2, c3] = float(coeff)
+    return cg
+
+
+@lru_cache(maxsize=10)
+def generate_clebsch_gordan_rsh(
+    lmax: int, parity_invariance: bool = True
+) -> torch.Tensor:
+    """
+    Generate Clebsch-Gordan coefficients for real spherical harmonics.
+
+    Args:
+        lmax: maximum angular momentum
+        parity_invariance: whether to enforce parity invariance, i.e. only allow
+            non-zero coefficients if :math:`(-1)^{l_1} (-1)^{l_2} = (-1)^{l_3}`
+    """
+    lidx, _ = sh_indices(lmax)
+    cg = generate_clebsch_gordan(lmax).to(dtype=torch.complex64)
+    complex_to_real = generate_sh_to_rsh(lmax)  # (real, complex)
+    cg_rsh = torch.einsum(
+        "ijk,mi,nj,ok->mno",
+        cg,
+        complex_to_real,
+        complex_to_real,
+        complex_to_real.conj(),
+    )
+
+    if parity_invariance:
+        parity = (-1.0) ** lidx
+        pmask = parity[:, None, None] * parity[None, :, None] == parity[None, None, :]
+        cg_rsh *= pmask
+    else:
+        lsum = lidx[:, None, None] + lidx[None, :, None] - lidx[None, None, :]
+        cg_rsh *= 1.0j ** lsum
+
+    # cast to real
+    cg_rsh = cg_rsh.real.to(torch.float64)
+    return cg_rsh
+
+
+def sparsify_clebsch_gordon(
+    cg: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Convert a dense Clebsch-Gordan tensor to sparse format.
+
+    Args:
+        cg: dense tensor of Clebsch-Gordan coefficients with shape
+            [(lmax_1+1)^2, (lmax_2+1)^2, (lmax_out+1)^2]
+
+    Returns:
+        cg_sparse: vector of non-zero CG coefficients
+        idx_in_1: indices for the first set of irreps
+        idx_in_2: indices for the second set of irreps
+        idx_out: indices for the output set of irreps
+    """
+    idx = torch.nonzero(cg)
+    idx_in_1, idx_in_2, idx_out = torch.split(idx, 1, dim=1)
+    idx_in_1, idx_in_2, idx_out = (
+        idx_in_1[:, 0],
+        idx_in_2[:, 0],
+        idx_out[:, 0],
+    )
+    cg_sparse = cg[idx_in_1, idx_in_2, idx_out]
+    return cg_sparse, idx_in_1, idx_in_2, idx_out
+
+
+def round_cmp(x: torch.Tensor, decimals: int = 1) -> torch.Tensor:
+    """Round the real and imaginary parts of a complex tensor separately."""
+    return torch.round(x.real, decimals=decimals) + 1j * torch.round(
+        x.imag, decimals=decimals
+    )
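
For orientation (editor's sketch, not part of the commit): the helpers above compose into a sparse tensor product of two irreps-valued feature vectors. The vectors x1 and x2 and the contraction below are illustrative assumptions; the commit's actual SO3TensorProduct layer (listed in docs/api/nn.rst above) is not shown in this diff.

import torch

lmax = 2
dim = (lmax + 1) ** 2  # 9 spherical-harmonic channels for lmax = 2

# dense CG tensor for real spherical harmonics, shape [9, 9, 9]
cg = generate_clebsch_gordan_rsh(lmax).to(torch.float32)

# keep only the non-zero couplings
cg_sparse, idx_1, idx_2, idx_out = sparsify_clebsch_gordon(cg)

# illustrative coupling of two feature vectors x1, x2
x1 = torch.randn(dim)
x2 = torch.randn(dim)
y = torch.zeros(dim)
# scatter-add the coupled products into the output channels
y.index_add_(0, idx_out, cg_sparse * x1[idx_1] * x2[idx_2])

# dense reference: y_k = sum_{i,j} cg[i, j, k] * x1_i * x2_j
y_ref = torch.einsum("ijk,i,j->k", cg, x1, x2)
assert torch.allclose(y, y_ref, atol=1e-5)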
