Skip to content

Commit 0da220c

Browse files
committed
fixes slurm weights saving
1 parent 848eb81 commit 0da220c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/trainer/test_trainer_steps.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def test_trainingstep_dict(tmpdir):
2424

2525
out = trainer.run_training_batch(batch, batch_idx)
2626
assert out.signal == 0
27-
assert out.all_log_metrics['log_acc1'] == 12.0
28-
assert out.all_log_metrics['log_acc2'] == 7.0
27+
assert out.batch_log_metrics['log_acc1'] == 12.0
28+
assert out.batch_log_metrics['log_acc2'] == 7.0
2929

3030
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
3131
assert pbar_metrics['pbar_acc1'] == 17.0
@@ -55,8 +55,8 @@ def training_step_with_step_end(tmpdir):
5555

5656
out = trainer.run_training_batch(batch, batch_idx)
5757
assert out.signal == 0
58-
assert out.all_log_metrics['log_acc1'] == 12.0
59-
assert out.all_log_metrics['log_acc2'] == 7.0
58+
assert out.batch_log_metrics['log_acc1'] == 12.0
59+
assert out.batch_log_metrics['log_acc2'] == 7.0
6060

6161
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
6262
assert pbar_metrics['pbar_acc1'] == 17.0
@@ -91,8 +91,8 @@ def test_full_training_loop_dict(tmpdir):
9191

9292
out = trainer.run_training_batch(batch, batch_idx)
9393
assert out.signal == 0
94-
assert out.all_log_metrics['log_acc1'] == 12.0
95-
assert out.all_log_metrics['log_acc2'] == 7.0
94+
assert out.batch_log_metrics['log_acc1'] == 12.0
95+
assert out.batch_log_metrics['log_acc2'] == 7.0
9696

9797
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
9898
assert pbar_metrics['pbar_acc1'] == 17.0
@@ -127,8 +127,8 @@ def test_train_step_epoch_end(tmpdir):
127127

128128
out = trainer.run_training_batch(batch, batch_idx)
129129
assert out.signal == 0
130-
assert out.all_log_metrics['log_acc1'] == 12.0
131-
assert out.all_log_metrics['log_acc2'] == 7.0
130+
assert out.batch_log_metrics['log_acc1'] == 12.0
131+
assert out.batch_log_metrics['log_acc2'] == 7.0
132132

133133
pbar_metrics = out.training_step_output_for_epoch_end['pbar_on_batch_end']
134134
assert pbar_metrics['pbar_acc1'] == 17.0

0 commit comments

Comments
 (0)