From ad05f206da5ebee273817dddaef4f860b6e80673 Mon Sep 17 00:00:00 2001 From: mgastegger Date: Thu, 1 Dec 2022 10:52:26 +0100 Subject: [PATCH] made vector_representation in SO3net optional --- .../configs/model/representation/so3net.yaml | 1 + src/schnetpack/representation/so3net.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/schnetpack/configs/model/representation/so3net.yaml b/src/schnetpack/configs/model/representation/so3net.yaml index 14ff4ec2c..8c8089c83 100644 --- a/src/schnetpack/configs/model/representation/so3net.yaml +++ b/src/schnetpack/configs/model/representation/so3net.yaml @@ -7,6 +7,7 @@ n_atom_basis: 128 n_interactions: 3 lmax: 2 shared_interactions: False +return_vector_representation: False cutoff_fn: _target_: schnetpack.nn.cutoff.CosineCutoff cutoff: ${globals.cutoff} \ No newline at end of file diff --git a/src/schnetpack/representation/so3net.py b/src/schnetpack/representation/so3net.py index 153837be8..be10fcdd9 100644 --- a/src/schnetpack/representation/so3net.py +++ b/src/schnetpack/representation/so3net.py @@ -27,6 +27,7 @@ def __init__( cutoff_fn: Optional[Callable] = None, shared_interactions: bool = False, max_z: int = 100, + return_vector_representation: bool = False, ): """ Args: @@ -39,6 +40,8 @@ def __init__( shared_interactions: max_z: conv_layer: + return_vector_representation: return l=1 features in Cartesian XYZ order + (e.g. for DipoleMoment output module) """ super(SO3net, self).__init__() @@ -48,6 +51,7 @@ def __init__( self.cutoff_fn = hydra.utils.instantiate(cutoff_fn) self.cutoff = cutoff_fn.cutoff self.radial_basis = hydra.utils.instantiate(radial_basis) + self.return_vector_representation = return_vector_representation self.embedding = nn.Embedding(max_z, n_atom_basis, padding_idx=0) self.sphharm = so3.RealSphericalHarmonics(lmax=lmax) @@ -118,7 +122,10 @@ def forward(self, inputs: Dict[str, torch.Tensor]): x = x + dx inputs["scalar_representation"] = x[:, 0] - # extract cartesian vector from multipoles: [y, z, x] -> [x, y, z] - inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1) inputs["multipole_representation"] = x + + # extract cartesian vector from multipoles: [y, z, x] -> [x, y, z] + if self.return_vector_representation: + inputs["vector_representation"] = torch.roll(x[:, 1:4], 1, 1) + return inputs