Skip to content

Commit d9f60b4

Browse files
justusschockBorda
authored andcommitted
generalize reinstantiation of dataloader (Lightning-AI#1346)
* generalize reinstantiation of dataloader * fix condition * add test * update changelog * fix changelog Co-authored-by: J. Borovec <[email protected]>
1 parent 68420c3 commit d9f60b4

File tree

3 files changed

+50
-19
lines changed

3 files changed

+50
-19
lines changed

CHANGELOG.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2222
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
2323
- Added testing for python 3.8 ([#915](https://github.com/PyTorchLightning/pytorch-lightning/pull/915))
2424
- Added a `training_epoch_end` method which is the mirror of `validation_epoch_end`. ([#1357](https://github.com/PyTorchLightning/pytorch-lightning/pull/1357))
25+
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
26+
- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
27+
- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
28+
2529
### Changed
2630

2731
- Changed `progress_bar_refresh_rate` trainer flag to disable progress bar when set to 0. ([#1108](https://github.com/PyTorchLightning/pytorch-lightning/pull/1108))
2832
- Enhanced `load_from_checkpoint` to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
2933
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
30-
- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
3134
- Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
32-
- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
33-
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
34-
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
35-
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
3635
- Allow to upload models on W&B ([#1339](https://github.com/PyTorchLightning/pytorch-lightning/pull/1339))
37-
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
3836
- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
39-
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
37+
- Did not always create a DataLoader during reinstantiation, but the same type as before (if subclass of DataLoader) ([#1346](https://github.com/PyTorchLightning/pytorch-lightning/pull/1346))
38+
- Did not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
4039
- Remove default Adam optimizer ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
4140
- Give warnings for unimplemented required lightning methods ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
4241
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
@@ -314,6 +313,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
314313
### Added
315314

316315
- Added the flag `log_gpu_memory` to `Trainer` to deactivate logging of GPU memory utilization
316+
- Added SLURM resubmit functionality (port from test-tube)
317317
- Added optional weight_save_path to trainer to remove the need for a checkpoint_callback when using cluster training
318318
- Added option to use single gpu per node with `DistributedDataParallel`
319319

pytorch_lightning/trainer/data_loading.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
8484

8585
if need_dist_sampler and no_sampler_added:
8686

87+
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
88+
8789
dl_args = {
88-
'dataset': dataloader.dataset,
89-
'batch_size': dataloader.batch_size,
90-
'shuffle': False,
91-
'num_workers': dataloader.num_workers,
92-
'collate_fn': dataloader.collate_fn,
93-
'pin_memory': dataloader.pin_memory,
94-
'drop_last': dataloader.drop_last,
95-
'timeout': dataloader.timeout,
96-
'worker_init_fn': dataloader.worker_init_fn
90+
k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
9791
}
9892

9993
if self.use_tpu:
@@ -102,13 +96,11 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
10296
num_replicas=xm.xrt_world_size(),
10397
rank=xm.get_ordinal()
10498
)
105-
dl_args['shuffle'] = False
10699
else:
107100
sampler = DistributedSampler(dataloader.dataset)
108-
dl_args['shuffle'] = False
109101

110102
dl_args['sampler'] = sampler
111-
dataloader = DataLoader(**dl_args)
103+
dataloader = type(dataloader)(**dl_args)
112104

113105
return dataloader
114106

tests/trainer/test_dataloaders.py

+39
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import torch
23

34
import tests.base.utils as tutils
45
from pytorch_lightning import Trainer
@@ -482,3 +483,41 @@ class CurrentTestModel(
482483
test_percent_check=0.5
483484
)
484485
trainer.fit(model)
486+
487+
488+
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
489+
def test_dataloader_reinit_for_subclass():
490+
491+
class CustomDataLoader(torch.utils.data.DataLoader):
492+
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
493+
batch_sampler=None, num_workers=0, collate_fn=None,
494+
pin_memory=False, drop_last=False, timeout=0,
495+
worker_init_fn=None, dummy_kwarg=None):
496+
super().__init__(dataset,
497+
batch_size,
498+
shuffle,
499+
sampler,
500+
batch_sampler,
501+
num_workers,
502+
collate_fn,
503+
pin_memory,
504+
drop_last,
505+
timeout,
506+
worker_init_fn)
507+
508+
self.dummy_kwarg = dummy_kwarg
509+
510+
trainer = Trainer(gpus=[0, 1],
511+
num_nodes=1,
512+
distributed_backend='ddp')
513+
514+
class CustomDummyObj:
515+
sampler = None
516+
517+
result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
518+
assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
519+
520+
result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
521+
assert isinstance(result, torch.utils.data.DataLoader)
522+
assert isinstance(result, CustomDataLoader)
523+
assert hasattr(result, 'dummy_kwarg')

0 commit comments

Comments
 (0)