Skip to content

Commit 44847a4

Browse files
authored
made vector_representation in SO3net optional (#465)
1 parent 02bf00b commit 44847a4

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

Diff for: src/schnetpack/configs/model/representation/so3net.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ n_atom_basis: 128
77
n_interactions: 3
88
lmax: 2
99
shared_interactions: False
10+
return_vector_representation: False
1011
cutoff_fn:
1112
_target_: schnetpack.nn.cutoff.CosineCutoff
1213
cutoff: ${globals.cutoff}

Diff for: src/schnetpack/representation/so3net.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
cutoff_fn: Optional[Callable] = None,
2828
shared_interactions: bool = False,
2929
max_z: int = 100,
30+
return_vector_representation: bool = False,
3031
):
3132
"""
3233
Args:
@@ -39,6 +40,8 @@ def __init__(
3940
shared_interactions:
4041
max_z:
4142
conv_layer:
43+
return_vector_representation: return l=1 features in Cartesian XYZ order
44+
(e.g. for DipoleMoment output module)
4245
"""
4346
super(SO3net, self).__init__()
4447

@@ -48,6 +51,7 @@ def __init__(
4851
self.cutoff_fn = hydra.utils.instantiate(cutoff_fn)
4952
self.cutoff = cutoff_fn.cutoff
5053
self.radial_basis = hydra.utils.instantiate(radial_basis)
54+
self.return_vector_representation = return_vector_representation
5155

5256
self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0)
5357
self.sphharm = so3.RealSphericalHarmonics(lmax=lmax)
@@ -118,7 +122,10 @@ def forward(self, inputs: Dict[str, torch.Tensor]):
118122
x = x + dx
119123

120124
inputs["scalar_representation"] = x[:, 0]
121-
# extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
122-
inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1)
123125
inputs["multipole_representation"] = x
126+
127+
# extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
128+
if self.return_vector_representation:
129+
inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1)
130+
124131
return inputs

0 commit comments

Comments
 (0)