Skip to content

Commit 4bee960

Browse files
committed
add more test
1 parent 6e084d1 commit 4bee960

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

tests/trainer/test_trainer.py

+35-5
Original file line numberDiff line numberDiff line change
@@ -789,9 +789,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
789789
pytest.param(0.0), # this should run no sanity checks
790790
pytest.param(1),
791791
pytest.param(1.0),
792-
pytest.param(0.3),
792+
pytest.param(0.5),
793+
pytest.param(5),
793794
])
794795
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
796+
"""
797+
Test that num_sanity_val_steps!=-1 runs through all validation data once.
798+
Makes sure the number of sanity check batches is clipped to limit_val_batches.
799+
"""
800+
model = EvalModelTemplate()
801+
model.validation_step = model.validation_step__multiple_dataloaders
802+
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
803+
num_sanity_val_steps = 4
804+
805+
trainer = Trainer(
806+
default_root_dir=tmpdir,
807+
num_sanity_val_steps=num_sanity_val_steps,
808+
limit_val_batches=limit_val_batches, # should have no influence
809+
max_steps=1,
810+
)
811+
assert trainer.num_sanity_val_steps == num_sanity_val_steps
812+
val_dataloaders = model.val_dataloader__multiple_mixed_length()
813+
814+
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
815+
trainer.fit(model, val_dataloaders=val_dataloaders)
816+
assert mocked.call_count == sum(
817+
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
818+
)
819+
820+
821+
@pytest.mark.parametrize(['limit_val_batches'], [
822+
pytest.param(0.0), # this should run no sanity checks
823+
pytest.param(1),
824+
pytest.param(1.0),
825+
pytest.param(0.3),
826+
])
827+
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
795828
"""
796829
Test that num_sanity_val_steps=-1 runs through all validation data once.
797830
Makes sure the number of sanity check batches is clipped to limit_val_batches.
@@ -810,10 +843,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
810843

811844
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
812845
trainer.fit(model, val_dataloaders=val_dataloaders)
813-
if isinstance(limit_val_batches, float):
814-
assert mocked.call_count == sum(len(dl) * limit_val_batches for dl in val_dataloaders)
815-
if isinstance(limit_val_batches, int):
816-
assert mocked.call_count == sum(limit_val_batches for dl in val_dataloaders)
846+
assert mocked.call_count == sum(trainer.num_val_batches)
817847

818848

819849
@pytest.mark.parametrize("trainer_kwargs,expected", [

0 commit comments

Comments
 (0)