@@ -789,9 +789,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
789
789
pytest .param (0.0 ), # this should run no sanity checks
790
790
pytest .param (1 ),
791
791
pytest .param (1.0 ),
792
- pytest .param (0.3 ),
792
+ pytest .param (0.5 ),
793
+ pytest .param (5 ),
793
794
])
794
795
def test_num_sanity_val_steps (tmpdir , limit_val_batches ):
796
+ """
797
+ Test that num_sanity_val_steps!=-1 runs through all validation data once.
798
+ Makes sure the number of sanity check batches is clipped to limit_val_batches.
799
+ """
800
+ model = EvalModelTemplate ()
801
+ model .validation_step = model .validation_step__multiple_dataloaders
802
+ model .validation_epoch_end = model .validation_epoch_end__multiple_dataloaders
803
+ num_sanity_val_steps = 4
804
+
805
+ trainer = Trainer (
806
+ default_root_dir = tmpdir ,
807
+ num_sanity_val_steps = num_sanity_val_steps ,
808
+ limit_val_batches = limit_val_batches , # should have no influence
809
+ max_steps = 1 ,
810
+ )
811
+ assert trainer .num_sanity_val_steps == num_sanity_val_steps
812
+ val_dataloaders = model .val_dataloader__multiple_mixed_length ()
813
+
814
+ with patch .object (trainer , 'evaluation_forward' , wraps = trainer .evaluation_forward ) as mocked :
815
+ trainer .fit (model , val_dataloaders = val_dataloaders )
816
+ assert mocked .call_count == sum (
817
+ min (num_sanity_val_steps , num_batches ) for num_batches in trainer .num_val_batches
818
+ )
819
+
820
+
821
+ @pytest .mark .parametrize (['limit_val_batches' ], [
822
+ pytest .param (0.0 ), # this should run no sanity checks
823
+ pytest .param (1 ),
824
+ pytest .param (1.0 ),
825
+ pytest .param (0.3 ),
826
+ ])
827
+ def test_num_sanity_val_steps_neg_one (tmpdir , limit_val_batches ):
795
828
"""
796
829
Test that num_sanity_val_steps=-1 runs through all validation data once.
797
830
Makes sure the number of sanity check batches is clipped to limit_val_batches.
@@ -810,10 +843,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
810
843
811
844
with patch .object (trainer , 'evaluation_forward' , wraps = trainer .evaluation_forward ) as mocked :
812
845
trainer .fit (model , val_dataloaders = val_dataloaders )
813
- if isinstance (limit_val_batches , float ):
814
- assert mocked .call_count == sum (len (dl ) * limit_val_batches for dl in val_dataloaders )
815
- if isinstance (limit_val_batches , int ):
816
- assert mocked .call_count == sum (limit_val_batches for dl in val_dataloaders )
846
+ assert mocked .call_count == sum (trainer .num_val_batches )
817
847
818
848
819
849
@pytest .mark .parametrize ("trainer_kwargs,expected" , [
0 commit comments