@@ -53,19 +53,15 @@ def test_fit_val_loader_only(tmpdir):
53
53
54
54
55
55
@pytest .mark .parametrize ("dataloader_options" , [
56
- dict (val_check_interval = 1.1 ),
57
56
dict (val_check_interval = 10000 ),
58
57
])
59
58
def test_dataloader_config_errors_runtime (tmpdir , dataloader_options ):
60
-
61
59
model = EvalModelTemplate ()
62
-
63
60
trainer = Trainer (
64
61
default_root_dir = tmpdir ,
65
62
max_epochs = 1 ,
66
63
** dataloader_options ,
67
64
)
68
-
69
65
with pytest .raises (ValueError ):
70
66
# fit model
71
67
trainer .fit (model )
@@ -78,9 +74,13 @@ def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
78
74
dict (limit_val_batches = 1.2 ),
79
75
dict (limit_test_batches = - 0.1 ),
80
76
dict (limit_test_batches = 1.2 ),
77
+ dict (val_check_interval = - 0.1 ),
78
+ dict (val_check_interval = 1.2 ),
79
+ dict (overfit_batches = - 0.1 ),
80
+ dict (overfit_batches = 1.2 ),
81
81
])
82
82
def test_dataloader_config_errors_init (tmpdir , dataloader_options ):
83
- with pytest .raises (MisconfigurationException ):
83
+ with pytest .raises (MisconfigurationException , match = 'passed invalid value' ):
84
84
Trainer (
85
85
default_root_dir = tmpdir ,
86
86
max_epochs = 1 ,
@@ -256,6 +256,62 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
256
256
f'Multiple `test_dataloaders` not initiated properly, got { trainer .test_dataloaders } '
257
257
258
258
259
+ @pytest .mark .parametrize (['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ], [
260
+ pytest .param (0.0 , 0.0 , 0.0 ),
261
+ pytest .param (1.0 , 1.0 , 1.0 ),
262
+ ])
263
+ def test_inf_dataloaders_with_limit_percent_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
264
+ """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
265
+ model = EvalModelTemplate ()
266
+ model .train_dataloader = model .train_dataloader__infinite
267
+ model .val_dataloader = model .val_dataloader__infinite
268
+ model .test_dataloader = model .test_dataloader__infinite
269
+
270
+ trainer = Trainer (
271
+ default_root_dir = tmpdir ,
272
+ max_epochs = 1 ,
273
+ limit_train_batches = limit_train_batches ,
274
+ limit_val_batches = limit_val_batches ,
275
+ limit_test_batches = limit_test_batches ,
276
+ )
277
+
278
+ results = trainer .fit (model )
279
+ assert results == 1
280
+ assert trainer .num_training_batches == (0 if limit_train_batches == 0.0 else float ('inf' ))
281
+ assert trainer .num_val_batches [0 ] == (0 if limit_val_batches == 0.0 else float ('inf' ))
282
+
283
+ trainer .test (ckpt_path = None )
284
+ assert trainer .num_test_batches [0 ] == (0 if limit_test_batches == 0.0 else float ('inf' ))
285
+
286
+
287
+ @pytest .mark .parametrize (['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ], [
288
+ pytest .param (0 , 0 , 0 ),
289
+ pytest .param (10 , 10 , 10 ),
290
+ ])
291
+ def test_inf_dataloaders_with_limit_num_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
292
+ """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
293
+ model = EvalModelTemplate ()
294
+ model .train_dataloader = model .train_dataloader__infinite
295
+ model .val_dataloader = model .val_dataloader__infinite
296
+ model .test_dataloader = model .test_dataloader__infinite
297
+
298
+ trainer = Trainer (
299
+ default_root_dir = tmpdir ,
300
+ max_epochs = 1 ,
301
+ limit_train_batches = limit_train_batches ,
302
+ limit_val_batches = limit_val_batches ,
303
+ limit_test_batches = limit_test_batches ,
304
+ )
305
+
306
+ results = trainer .fit (model )
307
+ assert results
308
+ assert trainer .num_training_batches == limit_train_batches
309
+ assert trainer .num_val_batches [0 ] == limit_val_batches
310
+
311
+ trainer .test (ckpt_path = None )
312
+ assert trainer .num_test_batches [0 ] == limit_test_batches
313
+
314
+
259
315
@pytest .mark .parametrize (
260
316
['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ],
261
317
[
@@ -266,7 +322,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
266
322
]
267
323
)
268
324
def test_dataloaders_with_limit_percent_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
269
- """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
325
+ """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
270
326
model = EvalModelTemplate ()
271
327
model .val_dataloader = model .val_dataloader__multiple_mixed_length
272
328
model .test_dataloader = model .test_dataloader__multiple_mixed_length
@@ -307,7 +363,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
307
363
]
308
364
)
309
365
def test_dataloaders_with_limit_num_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
310
- """Verify num_batches for val & test dataloaders passed with batch limit as number"""
366
+ """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
311
367
os .environ ['PL_DEV_DEBUG' ] = '1'
312
368
313
369
model = EvalModelTemplate ()
@@ -436,7 +492,7 @@ def test_train_inf_dataloader_error(tmpdir):
436
492
437
493
trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , val_check_interval = 0.5 )
438
494
439
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
495
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
440
496
trainer .fit (model )
441
497
442
498
@@ -447,7 +503,7 @@ def test_val_inf_dataloader_error(tmpdir):
447
503
448
504
trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , limit_val_batches = 0.5 )
449
505
450
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
506
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
451
507
trainer .fit (model )
452
508
453
509
@@ -458,7 +514,7 @@ def test_test_inf_dataloader_error(tmpdir):
458
514
459
515
trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , limit_test_batches = 0.5 )
460
516
461
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
517
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
462
518
trainer .test (model )
463
519
464
520
@@ -774,7 +830,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
774
830
775
831
trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , val_check_interval = 0.5 )
776
832
777
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
833
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
778
834
trainer .fit (model )
779
835
780
836
@@ -785,7 +841,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
785
841
786
842
trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , limit_val_batches = 0.5 )
787
843
788
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
844
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
789
845
trainer .fit (model )
790
846
791
847
@@ -796,5 +852,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
796
852
797
853
trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , limit_test_batches = 0.5 )
798
854
799
- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
855
+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
800
856
trainer .test (model )
0 commit comments