Commit 9ebe93c

ethanwharris authored and akarnachev committed
Add support for iterable datasets when val_check_interval=1.0 (Lightning-AI#1283)
* Add support for iterable datasets when val_check_interval=1.0

* Update CHANGELOG.md
1 parent 3a4ee01 commit 9ebe93c

4 files changed: +41 −14 lines


CHANGELOG.md (+1)

@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))

 ### Changed
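For context, a minimal sketch of the workflow the new CHANGELOG entry describes (not part of this diff; `RandomIterableDataset` and the commented-out `LitModel` are illustrative assumptions): a `train_dataloader` built from an `IterableDataset` has no `__len__`, and with this change the default `val_check_interval=1.0` now means "validate when the loader is exhausted" instead of raising.

import torch
from torch.utils.data import DataLoader, IterableDataset


class RandomIterableDataset(IterableDataset):
    """Illustrative dataset: streams samples and deliberately defines no __len__."""

    def __init__(self, size: int, count: int):
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)


train_loader = DataLoader(RandomIterableDataset(size=32, count=100), batch_size=4)

# Before this commit, fitting a model whose train_dataloader is unsized required
# an integer val_check_interval; the default 1.0 raised MisconfigurationException.
# After this commit the default works and validation runs at the end of each epoch:
#
#     trainer = Trainer(max_epochs=1)          # val_check_interval defaults to 1.0
#     trainer.fit(LitModel(train_loader))      # LitModel is any LightningModule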

pytorch_lightning/trainer/data_loading.py (+12 −9)

@@ -136,16 +136,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
-
-            self._percent_range_check('val_check_interval')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
+            else:
+                self._percent_range_check('val_check_interval')

-            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-            self.val_check_batch = max(1, self.val_check_batch)
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)

     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
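To make the control flow above easier to follow, here is the same decision restated as a standalone function (a sketch only: `resolve_val_check_batch` and its `has_len` argument are invented names, the real code raises `MisconfigurationException` rather than `ValueError`, and it also runs `_percent_range_check` on the sized path):

def resolve_val_check_batch(has_len: bool, val_check_interval: float,
                            num_training_batches: int) -> float:
    """Sketch of the branching added to reset_train_dataloader (names invented)."""
    if not has_len:
        # Unsized loader (e.g. IterableDataset): a fraction of the epoch is
        # meaningless, so only the end-of-epoch default 1.0 or an int is allowed.
        if val_check_interval == 1.0:
            return float('inf')  # never fires mid-epoch; the last batch is handled in the loop
        raise ValueError('val_check_interval must be `1.0` or an int for an infinite DataLoader')
    # Sized loader: interpret the float as a fraction of the training batches.
    return max(1, int(num_training_batches * val_check_interval))


assert resolve_val_check_batch(has_len=False, val_check_interval=1.0, num_training_batches=0) == float('inf')
assert resolve_val_check_batch(has_len=True, val_check_interval=0.25, num_training_batches=100) == 25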

pytorch_lightning/trainer/training_loop.py (+19 −4)

@@ -400,8 +400,8 @@ def run_training_epoch(self):
             train_dataloader = train_dataloader.per_device_loader(device)

         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-                enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:

@@ -429,8 +429,10 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val

             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:

@@ -740,3 +742,16 @@ def call_checkpoint_callback(self):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
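The `is_last_batch` flag produced by `_with_is_last` is what lets `should_check_val` fire when `val_check_batch == float('inf')`. Its behavior in isolation (a quick check, assuming the helper stays importable at module level as in the diff):

from pytorch_lightning.trainer.training_loop import _with_is_last

# Each item is paired with a flag that is True only for the final element.
assert list(_with_is_last('abc')) == [('a', False), ('b', False), ('c', True)]

# Works for any iterable, including generators whose length is unknown up front,
# at the cost of buffering one element of lookahead.
assert list(_with_is_last(x * x for x in range(3))) == [(0, False), (1, False), (4, True)]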

tests/trainer/test_dataloaders.py (+9 −1)

@@ -372,7 +372,6 @@ class CurrentTestModel(
     )
     trainer.fit(model)

-    # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,

@@ -383,6 +382,15 @@ class CurrentTestModel(
     # verify training completed
     assert result == 1

+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+

 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""
