Skip to content

Commit 890bc18

Browse files
authored
Fix metrics computation (#455)
* Fix metrics computation, e.g. for RMSE * Fix unsupervised outputs
1 parent 21becbb commit 890bc18

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

Diff for: src/schnetpack/task.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,9 @@ def calculate_loss(self, pred, target):
6161
)
6262
return loss
6363

64-
def calculate_metrics(self, pred, target):
65-
metrics = {
66-
metric_name: metric(pred[self.name], target[self.target_property])
67-
for metric_name, metric in self.metrics.items()
68-
}
69-
return metrics
64+
def update_metrics(self, pred, target):
65+
for metric in self.metrics.values():
66+
metric(pred[self.name], target[self.target_property])
7067

7168

7269
class UnsupervisedModelOutput(ModelOutput):
@@ -82,12 +79,9 @@ def calculate_loss(self, pred, target=None):
8279
loss = self.loss_weight * self.loss_fn(pred[self.name])
8380
return loss
8481

85-
def calculate_metrics(self, pred, target=None):
86-
metrics = {
87-
metric_name: metric(pred[self.name])
88-
for metric_name, metric in self.metrics.items()
89-
}
90-
return metrics
82+
def update_metrics(self, pred, target=None):
83+
for metric in self.metrics.values():
84+
metric(pred[self.name])
9185

9286

9387
class AtomisticTask(pl.LightningModule):
@@ -149,7 +143,8 @@ def loss_fn(self, pred, batch):
149143

150144
def log_metrics(self, pred, targets, subset):
151145
for output in self.outputs:
152-
for metric_name, metric in output.calculate_metrics(pred, targets).items():
146+
output.update_metrics(pred, targets)
147+
for metric_name, metric in output.metrics.items():
153148
self.log(
154149
f"{subset}_{output.name}_{metric_name}",
155150
metric,

0 commit comments

Comments
 (0)