@@ -24,8 +24,8 @@ def test_trainingstep_dict(tmpdir):
24
24
25
25
out = trainer .run_training_batch (batch , batch_idx )
26
26
assert out .signal == 0
27
- assert out .all_log_metrics ['log_acc1' ] == 12.0
28
- assert out .all_log_metrics ['log_acc2' ] == 7.0
27
+ assert out .batch_log_metrics ['log_acc1' ] == 12.0
28
+ assert out .batch_log_metrics ['log_acc2' ] == 7.0
29
29
30
30
pbar_metrics = out .training_step_output_for_epoch_end ['pbar_on_batch_end' ]
31
31
assert pbar_metrics ['pbar_acc1' ] == 17.0
@@ -55,8 +55,8 @@ def training_step_with_step_end(tmpdir):
55
55
56
56
out = trainer .run_training_batch (batch , batch_idx )
57
57
assert out .signal == 0
58
- assert out .all_log_metrics ['log_acc1' ] == 12.0
59
- assert out .all_log_metrics ['log_acc2' ] == 7.0
58
+ assert out .batch_log_metrics ['log_acc1' ] == 12.0
59
+ assert out .batch_log_metrics ['log_acc2' ] == 7.0
60
60
61
61
pbar_metrics = out .training_step_output_for_epoch_end ['pbar_on_batch_end' ]
62
62
assert pbar_metrics ['pbar_acc1' ] == 17.0
@@ -91,8 +91,8 @@ def test_full_training_loop_dict(tmpdir):
91
91
92
92
out = trainer .run_training_batch (batch , batch_idx )
93
93
assert out .signal == 0
94
- assert out .all_log_metrics ['log_acc1' ] == 12.0
95
- assert out .all_log_metrics ['log_acc2' ] == 7.0
94
+ assert out .batch_log_metrics ['log_acc1' ] == 12.0
95
+ assert out .batch_log_metrics ['log_acc2' ] == 7.0
96
96
97
97
pbar_metrics = out .training_step_output_for_epoch_end ['pbar_on_batch_end' ]
98
98
assert pbar_metrics ['pbar_acc1' ] == 17.0
@@ -127,8 +127,8 @@ def test_train_step_epoch_end(tmpdir):
127
127
128
128
out = trainer .run_training_batch (batch , batch_idx )
129
129
assert out .signal == 0
130
- assert out .all_log_metrics ['log_acc1' ] == 12.0
131
- assert out .all_log_metrics ['log_acc2' ] == 7.0
130
+ assert out .batch_log_metrics ['log_acc1' ] == 12.0
131
+ assert out .batch_log_metrics ['log_acc2' ] == 7.0
132
132
133
133
pbar_metrics = out .training_step_output_for_epoch_end ['pbar_on_batch_end' ]
134
134
assert pbar_metrics ['pbar_acc1' ] == 17.0
0 commit comments