diff --git a/CHANGELOG.md b/CHANGELOG.md
index 90c9f49566bae..f4f600282a39d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
 - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
 - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
+- Fixed a bug that created an extra dataloader with `reload_dataloaders_every_epoch` enabled ([#1181](https://github.com/PyTorchLightning/pytorch-lightning/issues/1181))
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
 - Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index bfbca4c25db29..f36ed898cffbc 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -338,14 +338,14 @@ def run_evaluation(self, test_mode: bool = False):
 
         # select dataloaders
         if test_mode:
-            if self.reload_dataloaders_every_epoch or self.test_dataloaders is None:
+            if self.test_dataloaders is None:
                 self.reset_test_dataloader(model)
 
             dataloaders = self.test_dataloaders
             max_batches = self.num_test_batches
         else:
             # val
-            if self.reload_dataloaders_every_epoch or self.val_dataloaders is None:
+            if self.val_dataloaders is None:
                 self.reset_val_dataloader(model)
 
             dataloaders = self.val_dataloaders
@@ -399,6 +399,15 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.val_progress_bar.close()
 
+        # reload dataloaders for the next epoch if requested
+        if test_mode:
+            if self.reload_dataloaders_every_epoch:
+                self.reset_test_dataloader(model)
+        else:
+            # val
+            if self.reload_dataloaders_every_epoch:
+                self.reset_val_dataloader(model)
+
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 9994af6c2d108..5081104b77ccd 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -275,7 +275,6 @@ def __init__(
                           " and this method will be removed in v0.8.0", DeprecationWarning)
 
         self.gradient_clip = gradient_clip
-        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.check_val_every_n_epoch = check_val_every_n_epoch
         self.track_grad_norm = track_grad_norm
@@ -320,6 +319,8 @@ def __init__(
                           " NaN grads will be printed automatically when detected.", DeprecationWarning)
 
+        self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
+
         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index cc44882ea5a58..84a8139645016 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -289,7 +289,9 @@ def train(self):
         model = self.get_model()
 
         # load data
-        self.reset_train_dataloader(model)
+        # if reload_dataloaders_every_epoch, this is moved to the epoch loop
+        if not self.reload_dataloaders_every_epoch:
+            self.reset_train_dataloader(model)
         self.reset_val_dataloader(model)
 
         # Train start events
@@ -305,6 +307,9 @@ def train(self):
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
+                # reset train dataloader
+                if self.reload_dataloaders_every_epoch:
+                    self.reset_train_dataloader(model)
                 # set seed for distributed sampler (enables shuffling for each epoch)
                 if self.use_ddp \
                         and hasattr(self.train_dataloader.sampler, 'set_epoch'):
@@ -393,10 +398,6 @@ def run_training_epoch(self):
         if self.is_function_implemented('on_epoch_start'):
             self.get_model().on_epoch_start()
 
-        # reset train dataloader
-        if self.reload_dataloaders_every_epoch:
-            self.reset_train_dataloader(self.get_model())
-
         # track local dataloader so TPU can wrap each epoch
         train_dataloader = self.train_dataloader
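For reference, a minimal sketch of the behaviour this change targets: with `reload_dataloaders_every_epoch=True`, `train_dataloader()` should now be rebuilt exactly once per training epoch (inside the epoch loop) rather than an extra time around evaluation. The model, dataset sizes, and the call counter below are illustrative only, and the snippet assumes a PyTorch Lightning release that still exposes the `reload_dataloaders_every_epoch` Trainer flag (later releases renamed it to `reload_dataloaders_every_n_epochs`); the exact `LightningModule` hooks may differ slightly in the 0.7.x API this diff targets.

```python
# Minimal sketch (not part of the diff): count how often the train dataloader
# is rebuilt when reload_dataloaders_every_epoch is enabled. With this fix the
# count should match max_epochs rather than being inflated by an extra reload.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ReloadCountingModel(pl.LightningModule):
    """Toy model; names and sizes are illustrative only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.train_loader_calls = 0  # how many times train_dataloader() ran

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # rebuilt at the start of every epoch when
        # reload_dataloaders_every_epoch=True
        self.train_loader_calls += 1
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)


if __name__ == '__main__':
    model = ReloadCountingModel()
    trainer = pl.Trainer(max_epochs=3, reload_dataloaders_every_epoch=True)
    trainer.fit(model)
    print('train_dataloader() calls:', model.train_loader_calls)
```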