Commit 04935ea

TevenLeScao and Borda authored
fixed extra dataloader bug (#1196)
* fixed extra dataloader bug
* Update pytorch_lightning/trainer/training_loop.py
  Co-Authored-By: Jirka Borovec <[email protected]>
* updated CHANGELOG
* Small non-repetition change self.get_model() => model as it was already defined
* Update CHANGELOG.md
* changed argument name to reload_train_dataloader_every_epoch
* fixed doc underline too short
* reverted to `reload_dataloaders_every_epoch`
* fixed val and test reloading
* fixed val and test reloading

Co-authored-by: TevenLeScao <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
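For orientation, a minimal sketch of how the affected flag is used, assuming the Trainer API of this release; `model` and `max_epochs=3` are hypothetical stand-ins, while `reload_dataloaders_every_epoch` comes from the diffs below:

```python
import pytorch_lightning as pl

# With reload_dataloaders_every_epoch=True the Trainer re-calls the
# LightningModule's train_dataloader() at every epoch; the bug fixed here
# made it build one extra, unused loader at the start of training.
trainer = pl.Trainer(max_epochs=3, reload_dataloaders_every_epoch=True)
# trainer.fit(model)  # `model` would be a LightningModule defining train_dataloader()
```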
1 parent e48422d · commit 04935ea

File tree (4 files changed: +20 -8 lines)

* CHANGELOG.md
* pytorch_lightning/trainer/evaluation_loop.py
* pytorch_lightning/trainer/trainer.py
* pytorch_lightning/trainer/training_loop.py

CHANGELOG.md (+1)

```diff
@@ -43,6 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
 - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
 - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
+- Fixed a bug that created an extra dataloader with active `reload_dataloaders_every_epoch` ([#1181](https://github.com/PyTorchLightning/pytorch-lightning/issues/1181)
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
 - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
 - Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
```

pytorch_lightning/trainer/evaluation_loop.py (+11 -2)

```diff
@@ -338,14 +338,14 @@ def run_evaluation(self, test_mode: bool = False):
 
         # select dataloaders
         if test_mode:
-            if self.reload_dataloaders_every_epoch or self.test_dataloaders is None:
+            if self.test_dataloaders is None:
                 self.reset_test_dataloader(model)
 
             dataloaders = self.test_dataloaders
             max_batches = self.num_test_batches
         else:
             # val
-            if self.reload_dataloaders_every_epoch or self.val_dataloaders is None:
+            if self.val_dataloaders is None:
                 self.reset_val_dataloader(model)
 
             dataloaders = self.val_dataloaders
@@ -399,6 +399,15 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.val_progress_bar.close()
 
+        # eventual dataset reloading
+        if test_mode:
+            if self.reload_dataloaders_every_epoch:
+                self.reset_test_dataloader(model)
+        else:
+            # val
+            if self.reload_dataloaders_every_epoch:
+                self.reset_val_dataloader(model)
+
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
```
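The net effect of these hunks: `run_evaluation` now builds a loader only when none exists, and the reload for the next pass happens after evaluation rather than before it, so the loader built in `train()` is actually consumed instead of thrown away. A toy model of that ordering, an assumed simplification rather than Lightning's actual code:

```python
# Toy model of the new run_evaluation ordering (assumed simplification).
def run_evaluation(state, reload_every_epoch, builds):
    if state.get("loader") is None:
        builds.append("lazy")          # build only when no loader exists yet
        state["loader"] = object()
    # ... evaluation loop would iterate state["loader"] here ...
    if reload_every_epoch:
        builds.append("post-eval")     # fresh data for the *next* pass
        state["loader"] = object()

builds = []
state = {"loader": object()}           # train() already built the val loader
run_evaluation(state, reload_every_epoch=True, builds=builds)
print(builds)  # ['post-eval'] -- the pre-built loader was used, not duplicated
```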

pytorch_lightning/trainer/trainer.py (+2 -1)

```diff
@@ -274,7 +274,6 @@ def __init__(
                           " and this method will be removed in v0.8.0", DeprecationWarning)
             self.gradient_clip = gradient_clip
 
-        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.check_val_every_n_epoch = check_val_every_n_epoch
         self.track_grad_norm = track_grad_norm
@@ -319,6 +318,8 @@ def __init__(
                           " NaN grads will be printed automatically when detected.",
                           DeprecationWarning)
 
+        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
+
         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
```

pytorch_lightning/trainer/training_loop.py (+6 -5)

```diff
@@ -290,7 +290,9 @@ def train(self):
         model = self.get_model()
 
         # load data
-        self.reset_train_dataloader(model)
+        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
+        if not self.reload_dataloaders_every_epoch:
+            self.reset_train_dataloader(model)
         self.reset_val_dataloader(model)
 
         # Train start events
@@ -306,6 +308,9 @@ def train(self):
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
+                # reset train dataloader
+                if self.reload_dataloaders_every_epoch:
+                    self.reset_train_dataloader(model)
                 # set seed for distributed sampler (enables shuffling for each epoch)
                 if self.use_ddp \
                         and hasattr(self.train_dataloader.sampler, 'set_epoch'):
@@ -394,10 +399,6 @@ def run_training_epoch(self):
         if self.is_function_implemented('on_epoch_start'):
             self.get_model().on_epoch_start()
 
-        # reset train dataloader
-        if self.reload_dataloaders_every_epoch:
-            self.reset_train_dataloader(self.get_model())
-
         # track local dataloader so TPU can wrap each epoch
         train_dataloader = self.train_dataloader
 
```
