
Commit 5e63120

justusschock and Borda authored and akarnachev committed
generalize reinstantiation of dataloader (Lightning-AI#1346)
* generalize reinstantiation of dataloader
* fix condition
* add test
* update changelog
* fix changelog

Co-authored-by: J. Borovec <[email protected]>
1 parent b4a0413 commit 5e63120

File tree

2 files changed (+43, -12 lines)


pytorch_lightning/trainer/data_loading.py

+4 -12
@@ -84,16 +84,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
 
         if need_dist_sampler and no_sampler_added:
 
+            skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
+
             dl_args = {
-                'dataset': dataloader.dataset,
-                'batch_size': dataloader.batch_size,
-                'shuffle': False,
-                'num_workers': dataloader.num_workers,
-                'collate_fn': dataloader.collate_fn,
-                'pin_memory': dataloader.pin_memory,
-                'drop_last': dataloader.drop_last,
-                'timeout': dataloader.timeout,
-                'worker_init_fn': dataloader.worker_init_fn
+                k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }
 
             if self.use_tpu:
@@ -102,13 +96,11 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                     num_replicas=xm.xrt_world_size(),
                     rank=xm.get_ordinal()
                 )
-                dl_args['shuffle'] = False
             else:
                 sampler = DistributedSampler(dataloader.dataset)
-                dl_args['shuffle'] = False
 
             dl_args['sampler'] = sampler
-            dataloader = DataLoader(**dl_args)
+            dataloader = type(dataloader)(**dl_args)
 
         return dataloader
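What the patch does, in short: rather than hard-coding the stock DataLoader constructor arguments, it harvests every public attribute from the loader's __dict__ and rebuilds the loader via type(dataloader), so DataLoader subclasses keep their concrete type and any extra constructor kwargs. Below is a minimal standalone sketch of the same idea; the function name reinstantiate_with_sampler is illustrative, not from the patch, and it assumes (as was true of stock DataLoader at the time) that every public attribute maps 1:1 onto a constructor parameter.

from torch.utils.data import DataLoader, SequentialSampler

def reinstantiate_with_sampler(dataloader, sampler):
    # 'sampler' and 'batch_sampler' are about to be replaced; 'dataset_kind'
    # is internal bookkeeping with no matching constructor parameter.
    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

    # Public attributes stored at construction time mirror constructor
    # arguments, so this also picks up subclass-specific kwargs.
    dl_args = {k: v for k, v in dataloader.__dict__.items()
               if not k.startswith('_') and k not in skip_keys}
    dl_args['sampler'] = sampler

    # type(dataloader) preserves the concrete subclass instead of
    # downcasting everything to torch.utils.data.DataLoader.
    return type(dataloader)(**dl_args)

# Single-process usage (a real DistributedSampler would need an initialized
# process group); newer PyTorch versions store extra internal state, so the
# skip list may need widening there.
loader = DataLoader(list(range(100)), batch_size=10)
new_loader = reinstantiate_with_sampler(loader, SequentialSampler(loader.dataset))
assert type(new_loader) is DataLoader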

tests/trainer/test_dataloaders.py

+39
@@ -1,4 +1,5 @@
 import pytest
+import torch
 
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
@@ -482,3 +483,41 @@ class CurrentTestModel(
         test_percent_check=0.5
     )
     trainer.fit(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_reinit_for_subclass():
+
+    class CustomDataLoader(torch.utils.data.DataLoader):
+        def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
+                     batch_sampler=None, num_workers=0, collate_fn=None,
+                     pin_memory=False, drop_last=False, timeout=0,
+                     worker_init_fn=None, dummy_kwarg=None):
+            super().__init__(dataset,
+                             batch_size,
+                             shuffle,
+                             sampler,
+                             batch_sampler,
+                             num_workers,
+                             collate_fn,
+                             pin_memory,
+                             drop_last,
+                             timeout,
+                             worker_init_fn)
+
+            self.dummy_kwarg = dummy_kwarg
+
+    trainer = Trainer(gpus=[0, 1],
+                      num_nodes=1,
+                      distributed_backend='ddp')
+
+    class CustomDummyObj:
+        sampler = None
+
+    result = trainer.auto_add_sampler(CustomDummyObj(), train=True)
+    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
+
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000))), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
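The last assertion is the point of the patch: because the subclass stores dummy_kwarg on self, it appears in the loader's __dict__ and is forwarded when the loader is rebuilt, whereas the old hard-coded argument list silently dropped it. A quick standalone illustration of the mechanism (not part of the patch):

from torch.utils.data import DataLoader

class CustomDataLoader(DataLoader):
    def __init__(self, dataset, dummy_kwarg=None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Stored on self, so it lands in __dict__ next to batch_size & co.
        self.dummy_kwarg = dummy_kwarg

loader = CustomDataLoader(list(range(10)), dummy_kwarg='hello', batch_size=2)

# Exactly what the dict comprehension in the patch harvests:
public_args = {k: v for k, v in loader.__dict__.items() if not k.startswith('_')}
print('dummy_kwarg' in public_args)  # True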
