From e81cc3be52de63952e4160b896374ecf517aaec7 Mon Sep 17 00:00:00 2001 From: ktschuett Date: Thu, 3 Nov 2022 14:05:41 +0100 Subject: [PATCH 1/2] Fix metrics computation, e.g. for RMSE --- src/schnetpack/task.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/schnetpack/task.py b/src/schnetpack/task.py index 6937ab731..2ae0f7ea6 100644 --- a/src/schnetpack/task.py +++ b/src/schnetpack/task.py @@ -61,12 +61,9 @@ def calculate_loss(self, pred, target): ) return loss - def calculate_metrics(self, pred, target): - metrics = { - metric_name: metric(pred[self.name], target[self.target_property]) - for metric_name, metric in self.metrics.items() - } - return metrics + def update_metrics(self, pred, target): + for metric in self.metrics.values(): + metric(pred[self.name], target[self.target_property]) class UnsupervisedModelOutput(ModelOutput): @@ -82,7 +79,7 @@ def calculate_loss(self, pred, target=None): loss = self.loss_weight * self.loss_fn(pred[self.name]) return loss - def calculate_metrics(self, pred, target=None): + def update_metrics(self, pred, target=None): metrics = { metric_name: metric(pred[self.name]) for metric_name, metric in self.metrics.items() @@ -149,7 +146,8 @@ def loss_fn(self, pred, batch): def log_metrics(self, pred, targets, subset): for output in self.outputs: - for metric_name, metric in output.calculate_metrics(pred, targets).items(): + output.update_metrics(pred, targets) + for metric_name, metric in output.metrics.items(): self.log( f"{subset}_{output.name}_{metric_name}", metric, From 8504ff5b91d79d7afb54c2b480864aa1f61ce4e8 Mon Sep 17 00:00:00 2001 From: ktschuett Date: Thu, 3 Nov 2022 14:10:08 +0100 Subject: [PATCH 2/2] Fix unsupervised outputs --- src/schnetpack/task.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/schnetpack/task.py b/src/schnetpack/task.py index 2ae0f7ea6..581250c00 100644 --- a/src/schnetpack/task.py +++ b/src/schnetpack/task.py @@ -80,11 +80,8 @@ def calculate_loss(self, pred, target=None): return loss def update_metrics(self, pred, target=None): - metrics = { - metric_name: metric(pred[self.name]) - for metric_name, metric in self.metrics.items() - } - return metrics + for metric in self.metrics.values(): + metric(pred[self.name]) class AtomisticTask(pl.LightningModule):