@@ -214,12 +214,15 @@ def test_val_step_only_step_metrics(tmpdir):
214
214
215
215
# make sure we logged the correct epoch metrics
216
216
total_empty_epoch_metrics = 0
217
+ epoch = 0
217
218
for metric in trainer .dev_debugger .logged_metrics :
219
+ if 'epoch' in metric :
220
+ epoch += 1
218
221
if len (metric ) > 2 :
219
222
assert 'no_val_no_pbar' not in metric
220
223
assert 'val_step_pbar_acc' not in metric
221
- assert metric ['val_step_log_acc' ]
222
- assert metric ['val_step_log_pbar_acc' ]
224
+ assert metric [f 'val_step_log_acc/epoch_ { epoch } ' ]
225
+ assert metric [f 'val_step_log_pbar_acc/epoch_ { epoch } ' ]
223
226
else :
224
227
total_empty_epoch_metrics += 1
225
228
@@ -228,6 +231,8 @@ def test_val_step_only_step_metrics(tmpdir):
228
231
# make sure we logged the correct epoch pbar metrics
229
232
total_empty_epoch_metrics = 0
230
233
for metric in trainer .dev_debugger .pbar_added_metrics :
234
+ if 'epoch' in metric :
235
+ epoch += 1
231
236
if len (metric ) > 2 :
232
237
assert 'no_val_no_pbar' not in metric
233
238
assert 'val_step_log_acc' not in metric
@@ -288,11 +293,12 @@ def test_val_step_epoch_step_metrics(tmpdir):
288
293
for metric_idx in range (0 , len (trainer .dev_debugger .logged_metrics ), batches + 1 ):
289
294
batch_metrics = trainer .dev_debugger .logged_metrics [metric_idx : metric_idx + batches ]
290
295
epoch_metric = trainer .dev_debugger .logged_metrics [metric_idx + batches ]
296
+ epoch = epoch_metric ['epoch' ]
291
297
292
298
# make sure the metric was split
293
299
for batch_metric in batch_metrics :
294
- assert 'step_val_step_log_acc' in batch_metric
295
- assert 'step_val_step_log_pbar_acc' in batch_metric
300
+ assert f 'step_val_step_log_acc/epoch_ { epoch } ' in batch_metric
301
+ assert f 'step_val_step_log_pbar_acc/epoch_ { epoch } ' in batch_metric
296
302
297
303
# make sure the epoch split was correct
298
304
assert 'epoch_val_step_log_acc' in epoch_metric
@@ -421,11 +427,11 @@ def test_val_step_full_loop_result_dp(tmpdir):
421
427
assert 'train_step_metric' in seen_keys
422
428
assert 'train_step_end_metric' in seen_keys
423
429
assert 'epoch_train_epoch_end_metric' in seen_keys
424
- assert 'step_validation_step_metric' in seen_keys
430
+ assert 'step_validation_step_metric/epoch_0 ' in seen_keys
425
431
assert 'epoch_validation_step_metric' in seen_keys
426
432
assert 'validation_step_end_metric' in seen_keys
427
433
assert 'validation_epoch_end_metric' in seen_keys
428
- assert 'step_test_step_metric' in seen_keys
434
+ assert 'step_test_step_metric/epoch_2 ' in seen_keys
429
435
assert 'epoch_test_step_metric' in seen_keys
430
436
assert 'test_step_end_metric' in seen_keys
431
437
assert 'test_epoch_end_metric' in seen_keys
0 commit comments