
Commit a642349

rohitgr7, Borda, awaelchli, and ethanwharris authored
Support limit_mode_batches (int) for infinite dataloader (#2840)
* Support limit_mode_batches(int) for infinite dataloader
* flake8
* revert and update
* add and update tests
* pep8
* chlog
* Update CHANGELOG.md

  Co-authored-by: Adrian Wälchli <[email protected]>
* Add suggestions by @awaelchli
* docs
* Apply suggestions from code review

  Co-authored-by: Ethan Harris <[email protected]>
* Apply suggestions from code review
* fix
* max
* check
* add and update tests
* max
* check
* check
* check
* chlog
* tests
* update exception message
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent b37c35a commit a642349
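
For reviewers, a minimal usage sketch of the behaviour this commit enables. The `RandomIterableDataset` below is a hypothetical stand-in for any dataset without `__len__`; it is not part of this diff.

    # Sketch only: capping batches on a dataloader that has no __len__.
    import torch
    from torch.utils.data import DataLoader, IterableDataset
    from pytorch_lightning import Trainer

    class RandomIterableDataset(IterableDataset):
        """Hypothetical infinite stream: yields random samples forever, so len() is undefined."""
        def __iter__(self):
            while True:
                yield torch.randn(32)

    train_loader = DataLoader(RandomIterableDataset(), batch_size=16)

    # Previously only 0.0 or 1.0 were accepted for an infinite loader;
    # with this change an int caps the number of training batches per epoch.
    trainer = Trainer(max_epochs=1, limit_train_batches=10)
    # trainer.fit(model, train_loader)  # model is any LightningModule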

7 files changed: +115 −75 lines changed


CHANGELOG.md

+2

@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
 
+- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2840](https://github.com/PyTorchLightning/pytorch-lightning/pull/2840))
+
 - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
 
 ### Changed

docs/source/sequences.rst

+15 −6

@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
 .. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
 
-.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
-   a `hiddens` arg.
+.. note:: Using this feature requires updating your LightningModule's
+   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
 
 ----------
 
@@ -59,10 +59,13 @@ Iterable Datasets
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.
 
-.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
-   (specifying the number of training batches to run before validation) when initializing the Trainer.
-   This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
-   the validation interval when val_check_interval is less than one.
+.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
+   (specifying the number of training batches to run before validation) when initializing the Trainer. This is
+   because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
+   interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
+   an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
+   to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
+   Here mode can be train/val/test.
 
 .. testcode::
 
@@ -87,3 +90,9 @@ option when using sequential data.
 
     # Set val_check_interval
     trainer = Trainer(val_check_interval=100)
+
+    # Set limit_val_batches to 0.0 or 0
+    trainer = Trainer(limit_val_batches=0.0)
+
+    # Set limit_val_batches as an int
+    trainer = Trainer(limit_val_batches=100)
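
The note added above can be summarised by a small sketch of the resolution rule for a loader without ``__len__``. The helper name below is illustrative only; the actual logic lives in the Trainer internals changed later in this diff.

    # Illustrative mapping from limit_{mode}_batches to num_{mode}_batches
    # when the dataloader has no __len__ (assumed helper, not Lightning API).
    def resolve_num_batches_for_infinite_loader(limit_batches):
        if isinstance(limit_batches, int) or limit_batches == 0.0:
            return int(limit_batches)   # run exactly that many batches (0 disables the loop)
        if limit_batches == 1.0:
            return float('inf')         # iterate the whole (unbounded) stream
        raise ValueError('other fractions need a __len__ to scale against')

    assert resolve_num_batches_for_infinite_loader(0) == 0
    assert resolve_num_batches_for_infinite_loader(100) == 100
    assert resolve_num_batches_for_infinite_loader(1.0) == float('inf')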

pytorch_lightning/core/lightning.py

+1 −1

@@ -1771,7 +1771,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs):
         elif self.example_input_array is not None:
             input_data = self.example_input_array
         else:
-            raise ValueError('input_sample and example_input_array tensors are both missing.')
+            raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
 
         if 'example_outputs' not in kwargs:
             self.eval()

pytorch_lightning/trainer/data_loading.py

+20 −45

@@ -103,21 +103,6 @@ class TrainerDataLoadingMixin(ABC):
     def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    def _check_batch_limits(self, name: str) -> None:
-        # TODO: verify it is still needed and deprecate it..
-        value = getattr(self, name)
-
-        # ints are fine
-        if isinstance(value, int):
-            return
-
-        msg = f'`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}. (or pass in an int)'
-        if name == 'val_check_interval':
-            msg += ' If you want to disable validation set `limit_val_batches` to 0.0 instead.'
-
-        if not 0. <= value <= 1.:
-            raise ValueError(msg)
-
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         on_windows = platform.system() == 'Windows'
 
@@ -212,18 +197,18 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
+        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
-        self._check_batch_limits('limit_train_batches')
 
-        if not _has_len(self.train_dataloader):
-            self.num_training_batches = float('inf')
-        else:
-            # try getting the length
-            if isinstance(self.limit_train_batches, float):
-                self.num_training_batches = len(self.train_dataloader)
-                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-            else:
-                self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float('inf'):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                'When using an IterableDataset for `limit_train_batches`,'
+                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                ' `num_training_batches` to use.')
 
         # determine when to check validation
         # if int passed in, val checks that often
@@ -241,13 +226,10 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     self.val_check_batch = float('inf')
                 else:
                     raise MisconfigurationException(
-                        'When using an infinite DataLoader (e.g. with an IterableDataset'
-                        ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
+                        'When using an IterableDataset for `train_dataloader`,'
                         ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                         ' checking validation every k training batches.')
         else:
-            self._check_batch_limits('val_check_interval')
-
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
 
@@ -308,20 +290,16 @@ def _reset_eval_dataloader(
             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')
 
-            if num_batches != float('inf'):
-                self._check_batch_limits(f'limit_{mode}_batches')
-
-                # limit num batches either as a percent or num steps
-                if isinstance(limit_eval_batches, float):
-                    num_batches = int(num_batches * limit_eval_batches)
-                else:
-                    num_batches = min(len(dataloader), limit_eval_batches)
-
-            elif limit_eval_batches not in (0.0, 1.0):
+            # limit num batches either as a percent or num steps
+            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
+                num_batches = min(num_batches, int(limit_eval_batches))
+            elif num_batches != float('inf'):
+                num_batches = int(num_batches * limit_eval_batches)
+            elif limit_eval_batches != 1.0:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
-                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
+                    'When using an IterableDataset for `limit_{mode}_batches`,'
+                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                    f' `num_{mode}_batches` to use.')
 
             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
@@ -388,9 +366,6 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
     def determine_data_use_amount(self, overfit_batches: float) -> None:
         """Use less data for debugging purposes"""
        if overfit_batches > 0:
-            if isinstance(overfit_batches, float) and overfit_batches > 1:
-                raise ValueError('`overfit_batches` when used as a percentage must'
-                                 f' be in range 0.0 < x < 1.0 but got {overfit_batches:.3f}.')
            self.limit_train_batches = overfit_batches
            self.limit_val_batches = overfit_batches
            self.limit_test_batches = overfit_batches
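
The training and eval branches rewritten above follow the same pattern; here is a compact standalone restatement, with a made-up function name for illustration (not part of this diff).

    # Sketch of the shared limit-resolution pattern introduced in this diff.
    def apply_batch_limit(num_batches, limit_batches, name='limit_train_batches'):
        # num_batches is len(dataloader) for map-style data, float('inf') for infinite loaders.
        if isinstance(limit_batches, int) or limit_batches == 0.0:
            return min(num_batches, int(limit_batches))  # explicit batch count (or 0)
        if num_batches != float('inf'):
            return int(num_batches * limit_batches)      # fraction of a finite loader
        if limit_batches == 1.0:
            return num_batches                           # full infinite stream
        raise ValueError(f'`{name}` must be 0.0, 1.0 or an int for an infinite dataloader')

    assert apply_batch_limit(float('inf'), 10) == 10
    assert apply_batch_limit(1000, 0.1) == 100
    assert apply_batch_limit(float('inf'), 1.0) == float('inf')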

pytorch_lightning/trainer/trainer.py

+7 −9

@@ -534,7 +534,6 @@ def __init__(
         # logging
         self.configure_logger(logger)
         self.log_save_interval = log_save_interval
-        self.val_check_interval = val_check_interval
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
@@ -547,9 +546,6 @@ def __init__(
             )
             overfit_batches = overfit_pct
 
-        # convert floats to ints
-        self.overfit_batches = _determine_limit_batches(overfit_batches)
-
         # TODO: remove in 0.10.0
         if val_percent_check is not None:
             rank_zero_warn(
@@ -577,9 +573,11 @@ def __init__(
             )
             limit_train_batches = train_percent_check
 
-        self.limit_test_batches = _determine_limit_batches(limit_test_batches)
-        self.limit_val_batches = _determine_limit_batches(limit_val_batches)
-        self.limit_train_batches = _determine_limit_batches(limit_train_batches)
+        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
+        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
+        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
+        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
+        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
         self.determine_data_use_amount(self.overfit_batches)
 
         # AMP init
@@ -1430,12 +1428,12 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
 
-def _determine_limit_batches(batches: Union[int, float]) -> Union[int, float]:
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
         return batches
     elif batches > 1 and batches % 1.0 == 0:
         return int(batches)
     else:
         raise MisconfigurationException(
-            f'You have passed invalid value {batches}, it has to be in (0, 1) or nature number.'
+            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
         )
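
With `val_check_interval` and `overfit_batches` now routed through the renamed validator, a few concrete examples of what `Trainer(...)` accepts or rejects at construction time, based on the validator above and the tests below (treat this as a sketch, not exhaustive documentation):

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    Trainer(limit_train_batches=0.25)   # fraction in [0.0, 1.0] is kept as a float
    Trainer(limit_val_batches=10)       # whole number is used as an explicit batch count
    Trainer(val_check_interval=100)     # val_check_interval is now validated the same way

    try:
        Trainer(overfit_batches=1.2)    # neither a valid fraction nor a whole number
    except MisconfigurationException as err:
        print(err)                      # '...passed invalid value 1.2 for overfit_batches...'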

tests/models/test_onnx.py

+1 −1

@@ -85,7 +85,7 @@ def test_error_if_no_input(tmpdir):
     model = EvalModelTemplate()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onxx")
-    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
+    with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
         model.to_onnx(file_path)

tests/trainer/test_dataloaders.py

+69 −13

@@ -53,19 +53,15 @@ def test_fit_val_loader_only(tmpdir):
 
 
 @pytest.mark.parametrize("dataloader_options", [
-    dict(val_check_interval=1.1),
     dict(val_check_interval=10000),
 ])
 def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
-
     model = EvalModelTemplate()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
         **dataloader_options,
     )
-
     with pytest.raises(ValueError):
         # fit model
         trainer.fit(model)
@@ -78,9 +74,13 @@ def test_dataloader_config_errors_runtime(tmpdir, dataloader_options):
     dict(limit_val_batches=1.2),
     dict(limit_test_batches=-0.1),
     dict(limit_test_batches=1.2),
+    dict(val_check_interval=-0.1),
+    dict(val_check_interval=1.2),
+    dict(overfit_batches=-0.1),
+    dict(overfit_batches=1.2),
 ])
 def test_dataloader_config_errors_init(tmpdir, dataloader_options):
-    with pytest.raises(MisconfigurationException):
+    with pytest.raises(MisconfigurationException, match='passed invalid value'):
         Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
@@ -256,6 +256,62 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
+@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
+    pytest.param(0.0, 0.0, 0.0),
+    pytest.param(1.0, 1.0, 1.0),
+])
+def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results == 1
+    assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
+    assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))
+
+
+@pytest.mark.parametrize(['limit_train_batches', 'limit_val_batches', 'limit_test_batches'], [
+    pytest.param(0, 0, 0),
+    pytest.param(10, 10, 10),
+])
+def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results
+    assert trainer.num_training_batches == limit_train_batches
+    assert trainer.num_val_batches[0] == limit_val_batches
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == limit_test_batches
+
+
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
@@ -266,7 +322,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
     ]
 )
 def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +363,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
     ]
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
     os.environ['PL_DEV_DEBUG'] = '1'
 
     model = EvalModelTemplate()
@@ -436,7 +492,7 @@ def test_train_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -447,7 +503,7 @@ def test_val_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -458,7 +514,7 @@ def test_test_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
        trainer.test(model)
 
 
@@ -774,7 +830,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -785,7 +841,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
 
 
@@ -796,5 +852,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.test(model)
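
The `*_dataloader__infinite` hooks used in the new tests come from `EvalModelTemplate` in the test suite; conceptually they wrap something like the following (a sketch under that assumption, not the actual template code):

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class InfiniteDummyDataset(IterableDataset):
        """Yields (input, target) pairs forever, so the resulting DataLoader has no __len__."""
        def __iter__(self):
            while True:
                yield torch.randn(28 * 28), torch.tensor(0)

    def infinite_dataloader(batch_size=32):
        return DataLoader(InfiniteDummyDataset(), batch_size=batch_size)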
