
Commit b2e9607

Refactor dataloading (#955)
* Refactor dataloading
* Refactor dataloading
* Refactor dataloading
* Add shuffle to test
1 parent be24456 commit b2e9607

5 files changed: +115 -83 lines changed

pytorch_lightning/trainer/data_loading.py

+41 -67
@@ -1,22 +1,10 @@
-import warnings
 from abc import ABC

 import torch.distributed as dist
+from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
-from pytorch_lightning.utilities.debugging import MisconfigurationException

-try:
-    # loading for pyTorch 1.3
-    from torch.utils.data import IterableDataset
-except ImportError:
-    # loading for pyTorch 1.1
-    import torch
-    warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
-                  ' please upgrade to 1.2+' % torch.__version__, ImportWarning)
-    EXIST_ITER_DATASET = False
-else:
-    EXIST_ITER_DATASET = True
+from pytorch_lightning.utilities.debugging import MisconfigurationException

 try:
     from apex import amp
@@ -90,36 +78,19 @@ def call_prepare_data(self, model):
         model.prepare_data()

     def auto_add_sampler(self, dataloader, train):
-        # do nothing when user gives a sampler
-        dl_args = {
-            'dataset': dataloader.dataset,
-            'batch_size': dataloader.batch_size,
-            'shuffle': False,
-            'num_workers': dataloader.num_workers,
-            'collate_fn': dataloader.collate_fn,
-            'pin_memory': dataloader.pin_memory,
-            'drop_last': dataloader.drop_last,
-            'timeout': dataloader.timeout,
-            'worker_init_fn': dataloader.worker_init_fn
-        }
-
-        if train:
-            if self.use_ddp or self.use_ddp2:
-                sampler = DistributedSampler(dataloader.dataset)
-                dl_args['shuffle'] = False
+        if self.use_ddp or self.use_ddp2 or self.use_tpu:
+            dl_args = {
+                'dataset': dataloader.dataset,
+                'batch_size': dataloader.batch_size,
+                'shuffle': False,
+                'num_workers': dataloader.num_workers,
+                'collate_fn': dataloader.collate_fn,
+                'pin_memory': dataloader.pin_memory,
+                'drop_last': dataloader.drop_last,
+                'timeout': dataloader.timeout,
+                'worker_init_fn': dataloader.worker_init_fn
+            }

-            elif self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal()
-                )
-                dl_args['shuffle'] = False
-            else:
-                sampler = RandomSampler(dataloader.dataset)
-
-        # on not train
-        else:
             if self.use_tpu:
                 sampler = DistributedSampler(
                     dataloader.dataset,
@@ -128,12 +99,16 @@ def auto_add_sampler(self, dataloader, train):
                 )
                 dl_args['shuffle'] = False
             else:
-                sampler = SequentialSampler(dataloader.dataset)
+                if train:
+                    sampler = DistributedSampler(dataloader.dataset)
+                    dl_args['shuffle'] = False
+                else:
+                    sampler = SequentialSampler(dataloader.dataset)

-        dl_args['sampler'] = sampler
+            dl_args['sampler'] = sampler

-        new_dataloader = DataLoader(**dl_args)
-        return new_dataloader
+            dataloader = DataLoader(**dl_args)
+        return dataloader

     def reset_train_dataloader(self, model):
         """
@@ -148,12 +123,12 @@ def reset_train_dataloader(self, model):
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

-        # determine number of training batches
-        if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
+        self._percent_range_check('train_percent_check')
+
+        if self.is_infinite_dataloader(self.train_dataloader):
             self.num_training_batches = float('inf')
         else:
-            self._percent_range_check('train_percent_check')
-
+            # try getting the length
             self.num_training_batches = len(self.train_dataloader)
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)

@@ -168,27 +143,26 @@ def reset_train_dataloader(self, model):
                 f"to the number of the training batches ({self.num_training_batches}). "
                 f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
         else:
+            if self.is_infinite_dataloader(self.train_dataloader):
+                m = '''
+                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
+                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
+                must be an int. An int k specifies checking validation every k training batches.
+                '''
+                raise MisconfigurationException(m)
+
             self._percent_range_check('val_check_interval')

             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)

-        # support IterableDataset for train data
-        self.is_iterable_train_dataloader = (
-            EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
-        )
-        if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
-            m = '''
-            When using an iterableDataset for `train_dataloader`,
-            `Trainer(val_check_interval)` must be an int.
-            An int k specifies checking validation every k training batches
-            '''
-            raise MisconfigurationException(m)
-
-    def is_iterable_dataloader(self, dataloader):
-        return (
-            EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
-        )
+    def is_infinite_dataloader(self, dataloader):
+        try:
+            # try getting the length
+            _ = len(dataloader)
+            return False
+        except TypeError as e:
+            return True

     def reset_val_dataloader(self, model):
         """

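Note on the refactored `auto_add_sampler` above: the trainer now rebuilds the user's DataLoader only when running under DDP, DDP2, or TPU; in single-process training the loader (including its `shuffle` choice, which is why the test loader in this commit now passes `shuffle=True`) is left untouched. A minimal, self-contained sketch of the DDP train branch outside the Trainer — the function name `add_distributed_sampler` and the toy dataset are illustrative, not part of this commit:

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def add_distributed_sampler(dataloader, use_ddp):
    # sketch of the commit's DDP train branch: rebuild the loader with a
    # DistributedSampler and shuffle=False, reusing the original arguments
    if not use_ddp:
        return dataloader  # single process: leave the user's loader untouched
    dl_args = {
        'dataset': dataloader.dataset,
        'batch_size': dataloader.batch_size,
        'shuffle': False,
        'num_workers': dataloader.num_workers,
        'collate_fn': dataloader.collate_fn,
        'pin_memory': dataloader.pin_memory,
        'drop_last': dataloader.drop_last,
        'sampler': DistributedSampler(dataloader.dataset),
    }
    return DataLoader(**dl_args)

# usage: DistributedSampler needs an initialised process group, so only
# swap it in when torch.distributed is actually running
loader = DataLoader(list(range(100)), batch_size=10, shuffle=True)
loader = add_distributed_sampler(loader, use_ddp=dist.is_available() and dist.is_initialized())
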
pytorch_lightning/trainer/trainer.py

+2 -7
@@ -1114,19 +1114,14 @@ def run_pretrain_routine(self, model: LightningModule):
             self.run_evaluation(test_mode=True)
             return

-        # load the dataloaders
-        self.reset_train_dataloader(ref_model)
-        self.reset_val_dataloader(ref_model)
-
         # check if we should run validation during training
-        self.disable_validation = self.num_val_batches == 0 or not self.is_overriden('validation_step')
-        self.disable_validation = self.disable_validation and not self.fast_dev_run
+        self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run

         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         ref_model.on_sanity_check_start()
-        ref_model.on_train_start()
         if not self.disable_validation and self.num_sanity_val_steps > 0:
+            self.reset_val_dataloader(ref_model)
             # init progress bars for validation sanity check
             pbar = tqdm(desc='Validation sanity check',
                         total=self.num_sanity_val_steps * len(self.val_dataloaders),

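The trainer.py change above stops loading dataloaders eagerly in `run_pretrain_routine`: `disable_validation` is decided from the model alone, and the val dataloaders are built only right before the sanity check. A toy sketch of that lazy pattern — the class `LazyRoutine` is illustrative only, not Lightning code:

class LazyRoutine:
    # illustrative only: defer building the val loaders until the sanity check needs them
    def __init__(self, has_validation_step, fast_dev_run, num_sanity_val_steps=5):
        self.has_validation_step = has_validation_step
        self.fast_dev_run = fast_dev_run
        self.num_sanity_val_steps = num_sanity_val_steps
        self.val_dataloaders = None

    def reset_val_dataloader(self):
        print('building val dataloaders')
        self.val_dataloaders = ['val_loader']

    def run_pretrain_routine(self):
        # decided from model capabilities only; no dataloader lengths are consulted
        disable_validation = not self.has_validation_step and not self.fast_dev_run
        if not disable_validation and self.num_sanity_val_steps > 0:
            self.reset_val_dataloader()  # loaded lazily, just before the sanity check

LazyRoutine(has_validation_step=False, fast_dev_run=False).run_pretrain_routine()  # builds nothing
LazyRoutine(has_validation_step=True, fast_dev_run=False).run_pretrain_routine()   # builds val loaders
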
pytorch_lightning/trainer/training_loop.py

+17 -9
@@ -271,7 +271,7 @@ def is_function_implemented(self, m):
         pass

     @abstractmethod
-    def is_iterable_dataloader(self, dataloader):
+    def is_infinite_dataloader(self, dataloader):
         # this is just empty shell for code from other class
         pass

@@ -325,6 +325,11 @@ def reset_train_dataloader(self, model):
         # this is just empty shell for code from other class
         pass

+    @abstractmethod
+    def reset_val_dataloader(self, model):
+        # this is just empty shell for code from other class
+        pass
+
     @abstractmethod
     def has_arg(self, f_name, arg_name):
         # this is just empty shell for code from other class
@@ -334,11 +339,17 @@ def train(self):
         warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
                       ' but will start from "0" in v0.8.0.', DeprecationWarning)

+        # get model
+        model = self.get_model()
+
+        # load data
+        self.reset_train_dataloader(model)
+        self.reset_val_dataloader(model)
+
         # Train begin callbacks
+        model.on_train_start()
         self.on_train_start()

-        # get model
-        model = self.get_model()
         try:
             # run all epochs
             for epoch in range(self.current_epoch, self.max_epochs):
@@ -347,9 +358,6 @@ def train(self):
                         and hasattr(self.train_dataloader.sampler, 'set_epoch'):
                     self.train_dataloader.sampler.set_epoch(epoch)

-                # get model
-                model = self.get_model()
-
                 # update training progress in trainer and model
                 model.current_epoch = epoch
                 self.current_epoch = epoch
@@ -370,8 +378,8 @@ def train(self):
                 if self.fast_dev_run:
                     # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
                     num_iterations = 2
-                elif self.is_iterable_dataloader(self.train_dataloader):
-                    # for iterable train loader, the progress bar never ends
+                elif self.is_infinite_dataloader(self.train_dataloader):
+                    # for infinite train loader, the progress bar never ends
                     num_iterations = None
                 else:
                     num_iterations = self.total_batches
@@ -380,7 +388,7 @@ def train(self):
                 # .reset() doesn't work on disabled progress bar so we should check
                 if not self.main_progress_bar.disable:
                     self.main_progress_bar.reset(num_iterations)
-                desc = f'Epoch {epoch + 1}' if not self.is_iterable_dataloader(self.train_dataloader) else ''
+                desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
                 self.main_progress_bar.set_description(desc)

                 # changing gradient according accumulation_scheduler

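The abstract `reset_val_dataloader` stub added above follows the trainer's mixin convention: the loop mixin declares empty shells for methods it calls, and the concrete `Trainer` class picks up the real implementations from the data-loading mixin. A stripped-down sketch of that composition — the class names here are illustrative, not the actual Lightning classes:

from abc import ABC, abstractmethod

class LoopMixin(ABC):
    @abstractmethod
    def reset_val_dataloader(self, model):
        # this is just an empty shell; the real code lives in another mixin
        pass

    def train(self, model):
        # the loop may call the method even though it is defined elsewhere
        self.reset_val_dataloader(model)

class LoadingMixin:
    def reset_val_dataloader(self, model):
        print('resetting val dataloader')

class MiniTrainer(LoadingMixin, LoopMixin):
    # the concrete method from LoadingMixin satisfies the abstract shell
    pass

MiniTrainer().train(model=None)  # prints: resetting val dataloader
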
tests/models/base.py

+1 -0
@@ -168,6 +168,7 @@ def _dataloader(self, train):
         loader = DataLoader(
             dataset=dataset,
             batch_size=batch_size,
+            shuffle=True
         )

         return loader

tests/test_trainer.py

+54 -0
@@ -380,6 +380,60 @@ def test_model_freeze_unfreeze():
     model.unfreeze()


+def test_inf_train_dataloader(tmpdir):
+    """Test inf train data loader (e.g. IterableDataset)"""
+    tutils.reset_seed()
+
+    class CurrentTestModel(LightningTestModel):
+        def train_dataloader(self):
+            dataloader = self._dataloader(train=True)
+
+            class CustomInfDataLoader:
+                def __init__(self, dataloader):
+                    self.dataloader = dataloader
+                    self.iter = iter(dataloader)
+                    self.count = 0
+
+                def __iter__(self):
+                    self.count = 0
+                    return self
+
+                def __next__(self):
+                    if self.count >= 5:
+                        raise StopIteration
+                    self.count = self.count + 1
+                    try:
+                        return next(self.iter)
+                    except StopIteration:
+                        self.iter = iter(self.dataloader)
+                        return next(self.iter)
+
+            return CustomInfDataLoader(dataloader)
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # fit model
+    with pytest.raises(MisconfigurationException):
+        trainer = Trainer(
+            default_save_path=tmpdir,
+            max_epochs=1,
+            val_check_interval=0.5
+        )
+        trainer.fit(model)
+
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_check_interval=50,
+    )
+    result = trainer.fit(model)
+
+    # verify training completed
+    assert result == 1
+
+
 def test_multiple_val_dataloader(tmpdir):
     """Verify multiple val_dataloader."""
     tutils.reset_seed()

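The new test above relies on the fact that `CustomInfDataLoader` defines `__iter__`/`__next__` but no `__len__`, so the trainer's new `is_infinite_dataloader` check treats it as infinite: a float `val_check_interval` then raises `MisconfigurationException`, while an int interval is accepted. A small standalone sketch of that probe — the `NoLenLoader` class is illustrative, not from the commit:

class NoLenLoader:
    # iterable but without __len__, like CustomInfDataLoader in the test
    def __iter__(self):
        return iter(range(5))

def is_infinite(dataloader):
    # mirrors the length probe added in data_loading.py
    try:
        len(dataloader)
        return False
    except TypeError:
        return True

print(is_infinite([1, 2, 3]))      # False: a list has __len__
print(is_infinite(NoLenLoader()))  # True: triggers the int-only val_check_interval rule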