@@ -61,12 +61,9 @@ def calculate_loss(self, pred, target):
61
61
)
62
62
return loss
63
63
64
- def calculate_metrics (self , pred , target ):
65
- metrics = {
66
- metric_name : metric (pred [self .name ], target [self .target_property ])
67
- for metric_name , metric in self .metrics .items ()
68
- }
69
- return metrics
64
+ def update_metrics (self , pred , target ):
65
+ for metric in self .metrics .values ():
66
+ metric (pred [self .name ], target [self .target_property ])
70
67
71
68
72
69
class UnsupervisedModelOutput (ModelOutput ):
@@ -82,12 +79,9 @@ def calculate_loss(self, pred, target=None):
82
79
loss = self .loss_weight * self .loss_fn (pred [self .name ])
83
80
return loss
84
81
85
- def calculate_metrics (self , pred , target = None ):
86
- metrics = {
87
- metric_name : metric (pred [self .name ])
88
- for metric_name , metric in self .metrics .items ()
89
- }
90
- return metrics
82
+ def update_metrics (self , pred , target = None ):
83
+ for metric in self .metrics .values ():
84
+ metric (pred [self .name ])
91
85
92
86
93
87
class AtomisticTask (pl .LightningModule ):
@@ -149,7 +143,8 @@ def loss_fn(self, pred, batch):
149
143
150
144
def log_metrics (self , pred , targets , subset ):
151
145
for output in self .outputs :
152
- for metric_name , metric in output .calculate_metrics (pred , targets ).items ():
146
+ output .update_metrics (pred , targets )
147
+ for metric_name , metric in output .metrics .items ():
153
148
self .log (
154
149
f"{ subset } _{ output .name } _{ metric_name } " ,
155
150
metric ,
0 commit comments