diff --git a/src/schnetpack/task.py b/src/schnetpack/task.py index 6937ab731..581250c00 100644 --- a/src/schnetpack/task.py +++ b/src/schnetpack/task.py @@ -61,12 +61,9 @@ def calculate_loss(self, pred, target): ) return loss - def calculate_metrics(self, pred, target): - metrics = { - metric_name: metric(pred[self.name], target[self.target_property]) - for metric_name, metric in self.metrics.items() - } - return metrics + def update_metrics(self, pred, target): + for metric in self.metrics.values(): + metric(pred[self.name], target[self.target_property]) class UnsupervisedModelOutput(ModelOutput): @@ -82,12 +79,9 @@ def calculate_loss(self, pred, target=None): loss = self.loss_weight * self.loss_fn(pred[self.name]) return loss - def calculate_metrics(self, pred, target=None): - metrics = { - metric_name: metric(pred[self.name]) - for metric_name, metric in self.metrics.items() - } - return metrics + def update_metrics(self, pred, target=None): + for metric in self.metrics.values(): + metric(pred[self.name]) class AtomisticTask(pl.LightningModule): @@ -149,7 +143,8 @@ def loss_fn(self, pred, batch): def log_metrics(self, pred, targets, subset): for output in self.outputs: - for metric_name, metric in output.calculate_metrics(pred, targets).items(): + output.update_metrics(pred, targets) + for metric_name, metric in output.metrics.items(): self.log( f"{subset}_{output.name}_{metric_name}", metric,