Add support for iterable datasets when val_check_interval=1.0 #1283

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
+- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))

 ### Changed
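In practice, the changelog entry above means a model whose `train_dataloader` wraps an `IterableDataset` (so the DataLoader has no usable `__len__`) can now be trained with the default `Trainer()` settings. The sketch below is illustrative only, not code from this PR: `RandomIterableDataset` and `LitModel` are made-up names, and the exact `LightningModule` hooks required may vary between Lightning versions.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset
import pytorch_lightning as pl


class RandomIterableDataset(IterableDataset):
    """Streams random (x, y) pairs and deliberately does not define __len__."""

    def __iter__(self):
        for _ in range(64):
            yield torch.randn(16), torch.randn(1)


class LitModel(pl.LightningModule):
    """Placeholder model; only the dataloaders matter for this example."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(16, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.mse_loss(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.mse_loss(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        # DataLoader over an IterableDataset: calling len() on it raises TypeError
        return DataLoader(RandomIterableDataset(), batch_size=8)

    def val_dataloader(self):
        return DataLoader(RandomIterableDataset(), batch_size=8)


# Before this PR, the default val_check_interval=1.0 raised a
# MisconfigurationException for such a train loader; now validation runs once,
# at the end of each epoch.
trainer = pl.Trainer(max_epochs=1)
trainer.fit(LitModel())

# Passing an int still means "run validation every k training batches".
trainer = pl.Trainer(max_epochs=1, val_check_interval=50)
```

For finite, map-style datasets the behaviour is unchanged: a float `val_check_interval` is still interpreted as a fraction of the training epoch.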
21 changes: 12 additions & 9 deletions pytorch_lightning/trainer/data_loading.py
@@ -136,16 +136,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     'If you want to disable validation set `val_percent_check` to 0.0 instead.')
         else:
             if not _has_len(self.train_dataloader):
-                raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset or when '
-                    'DataLoader does not implement `__len__`) for `train_dataloader`, '
-                    '`Trainer(val_check_interval)` must be an int. An int k specifies checking '
-                    'validation every k training batches.')
-
-            self._percent_range_check('val_check_interval')
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float('inf')
+                else:
+                    raise MisconfigurationException(
+                        'When using an infinite DataLoader (e.g. with an IterableDataset or when '
+                        'DataLoader does not implement `__len__`) for `train_dataloader`, '
+                        '`Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies '
+                        'checking validation every k training batches.')
+            else:
+                self._percent_range_check('val_check_interval')

-            self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-            self.val_check_batch = max(1, self.val_check_batch)
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)

     def _reset_eval_dataloader(self, model: LightningModule,
                                mode: str) -> Tuple[int, List[DataLoader]]:
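The branch above keys off `_has_len(self.train_dataloader)`, a helper that already exists in this module. The idea it relies on can be sketched roughly as follows (an illustrative re-statement, not the PR's exact helper): a DataLoader built on an `IterableDataset` that does not implement `__len__` raises `TypeError` when `len()` is called on it, which is what makes the loader look "infinite" to the trainer.

```python
from torch.utils.data import DataLoader, IterableDataset


def has_len(dataloader) -> bool:
    """Rough sketch of an `_has_len`-style check: True for finite (map-style)
    loaders, False for loaders whose length cannot be determined."""
    try:
        len(dataloader)
        return True
    except TypeError:
        return False


class Stream(IterableDataset):
    def __iter__(self):
        yield from range(10)


print(has_len(DataLoader(Stream())))          # False -> val_check_interval must be 1.0 or an int
print(has_len(DataLoader(list(range(10)))))   # True  -> percent-based val_check_interval applies
```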
23 changes: 19 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -400,8 +400,8 @@ def run_training_epoch(self):
             train_dataloader = train_dataloader.per_device_loader(device)

         # run epoch
-        for batch_idx, batch in self.profiler.profile_iterable(
-                enumerate(train_dataloader), "get_train_batch"
+        for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
+                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
@@ -429,8 +429,10 @@ def run_training_epoch(self):
             # ---------------
             is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
             can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-            should_check_val = not self.disable_validation and can_check_epoch
-            should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
+            can_check_val = not self.disable_validation and can_check_epoch
+            should_check_val = is_val_check_batch or early_stop_epoch
+            should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
+            should_check_val = can_check_val and should_check_val

             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
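The reworked flags separate "is validation allowed in this epoch" (`can_check_val`) from "does this batch ask for validation" (`should_check_val`). With an infinite train loader, `val_check_batch` is `float('inf')`, so `(batch_idx + 1) % self.val_check_batch` is never zero and only the new `is_last_batch` clause fires, i.e. validation runs once at the end of the epoch. A stand-alone restatement of that logic (illustrative only, not the trainer's actual method):

```python
def should_run_val(batch_idx, is_last_batch, val_check_batch,
                   disable_validation=False, early_stop_epoch=False,
                   current_epoch=0, check_val_every_n_epoch=1):
    """Mirrors the flag computation above, outside the Trainer."""
    is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    can_check_epoch = (current_epoch + 1) % check_val_every_n_epoch == 0
    can_check_val = not disable_validation and can_check_epoch
    should_check_val = is_val_check_batch or early_stop_epoch
    should_check_val = should_check_val or (is_last_batch and val_check_batch == float('inf'))
    return can_check_val and should_check_val


# Infinite loader: only the final batch triggers validation.
assert not should_run_val(batch_idx=9, is_last_batch=False, val_check_batch=float('inf'))
assert should_run_val(batch_idx=9, is_last_batch=True, val_check_batch=float('inf'))

# Finite loader with val_check_batch=5: every 5th batch triggers validation.
assert should_run_val(batch_idx=4, is_last_batch=False, val_check_batch=5)
```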
@@ -740,3 +742,16 @@ def call_checkpoint_callback(self):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
         self.on_validation_end()
+
+
+def _with_is_last(iterable):
+    """Pass through values from the given iterable with an added boolean indicating if this is the last item.
+    See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
+    it = iter(iterable)
+    last = next(it)
+    for val in it:
+        # yield last and has next
+        yield last, False
+        last = val
+    # yield last, no longer has next
+    yield last, True
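A quick check of the helper, assuming the `_with_is_last` definition above is in scope. Note that it draws the first element eagerly, so it expects the iterable to yield at least one item.

```python
batches = ['b0', 'b1', 'b2']
print(list(_with_is_last(batches)))
# [('b0', False), ('b1', False), ('b2', True)]

# The training loop unpacks this as (batch, is_last_batch) and uses the flag to
# trigger end-of-epoch validation when val_check_batch is float('inf').
```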
10 changes: 9 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -372,7 +372,6 @@ class CurrentTestModel(
     )
     trainer.fit(model)

-    # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
         max_epochs=1,
@@ -383,6 +382,15 @@ class CurrentTestModel(
     # verify training completed
     assert result == 1

+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+

 def test_inf_val_dataloader(tmpdir):
     """Test inf val data loader (e.g. IterableDataset)"""