@@ -27,6 +27,7 @@ def __init__(
27
27
cutoff_fn : Optional [Callable ] = None ,
28
28
shared_interactions : bool = False ,
29
29
max_z : int = 100 ,
30
+ return_vector_representation : bool = False ,
30
31
):
31
32
"""
32
33
Args:
@@ -39,6 +40,8 @@ def __init__(
39
40
shared_interactions:
40
41
max_z:
41
42
conv_layer:
43
+ return_vector_representation: return l=1 features in Cartesian XYZ order
44
+ (e.g. for DipoleMoment output module)
42
45
"""
43
46
super (SO3net , self ).__init__ ()
44
47
@@ -48,6 +51,7 @@ def __init__(
48
51
self .cutoff_fn = hydra .utils .instantiate (cutoff_fn )
49
52
self .cutoff = cutoff_fn .cutoff
50
53
self .radial_basis = hydra .utils .instantiate (radial_basis )
54
+ self .return_vector_representation = return_vector_representation
51
55
52
56
self .embedding = nn .Embedding (max_z , n_atom_basis , padding_idx = 0 )
53
57
self .sphharm = so3 .RealSphericalHarmonics (lmax = lmax )
@@ -118,7 +122,10 @@ def forward(self, inputs: Dict[str, torch.Tensor]):
118
122
x = x + dx
119
123
120
124
inputs ["scalar_representation" ] = x [:, 0 ]
121
- # extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
122
- inputs ["vector_representation" ] = torch .roll (x [:, 1 :4 ], 1 , 1 )
123
125
inputs ["multipole_representation" ] = x
126
+
127
+ # extract cartesian vector from multipoles: [y, z, x] -> [x, y, z]
128
+ if self .return_vector_representation :
129
+ inputs ["vector_representation" ] = torch .roll (x [:, 1 :4 ], 1 , 1 )
130
+
124
131
return inputs
0 commit comments