@@ -1,22 +1,10 @@
-import warnings
 from abc import ABC
 
 import torch.distributed as dist
+from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data import RandomSampler, SequentialSampler, DataLoader, BatchSampler
-from pytorch_lightning.utilities.debugging import MisconfigurationException
 
-try:
-    # loading for pyTorch 1.3
-    from torch.utils.data import IterableDataset
-except ImportError:
-    # loading for pyTorch 1.1
-    import torch
-    warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
-                  ' please upgrade to 1.2+' % torch.__version__, ImportWarning)
-    EXIST_ITER_DATASET = False
-else:
-    EXIST_ITER_DATASET = True
+from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
     from apex import amp
@@ -90,36 +78,19 @@ def call_prepare_data(self, model):
             model.prepare_data()
 
     def auto_add_sampler(self, dataloader, train):
-        # do nothing when user gives a sampler
-        dl_args = {
-            'dataset': dataloader.dataset,
-            'batch_size': dataloader.batch_size,
-            'shuffle': False,
-            'num_workers': dataloader.num_workers,
-            'collate_fn': dataloader.collate_fn,
-            'pin_memory': dataloader.pin_memory,
-            'drop_last': dataloader.drop_last,
-            'timeout': dataloader.timeout,
-            'worker_init_fn': dataloader.worker_init_fn
-        }
-
-        if train:
-            if self.use_ddp or self.use_ddp2:
-                sampler = DistributedSampler(dataloader.dataset)
-                dl_args['shuffle'] = False
+        if self.use_ddp or self.use_ddp2 or self.use_tpu:
+            dl_args = {
+                'dataset': dataloader.dataset,
+                'batch_size': dataloader.batch_size,
+                'shuffle': False,
+                'num_workers': dataloader.num_workers,
+                'collate_fn': dataloader.collate_fn,
+                'pin_memory': dataloader.pin_memory,
+                'drop_last': dataloader.drop_last,
+                'timeout': dataloader.timeout,
+                'worker_init_fn': dataloader.worker_init_fn
+            }
 
-            elif self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal()
-                )
-                dl_args['shuffle'] = False
-            else:
-                sampler = RandomSampler(dataloader.dataset)
-
-        # on not train
-        else:
             if self.use_tpu:
                 sampler = DistributedSampler(
                     dataloader.dataset,
@@ -128,12 +99,16 @@ def auto_add_sampler(self, dataloader, train):
                 )
                 dl_args['shuffle'] = False
             else:
-                sampler = SequentialSampler(dataloader.dataset)
+                if train:
+                    sampler = DistributedSampler(dataloader.dataset)
+                    dl_args['shuffle'] = False
+                else:
+                    sampler = SequentialSampler(dataloader.dataset)
 
-        dl_args['sampler'] = sampler
+            dl_args['sampler'] = sampler
 
-        new_dataloader = DataLoader(**dl_args)
-        return new_dataloader
+            dataloader = DataLoader(**dl_args)
+        return dataloader
 
     def reset_train_dataloader(self, model):
         """
@@ -148,12 +123,12 @@ def reset_train_dataloader(self, model):
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
-        # determine number of training batches
-        if EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset):
+        self._percent_range_check('train_percent_check')
+
+        if self.is_infinite_dataloader(self.train_dataloader):
             self.num_training_batches = float('inf')
         else:
-            self._percent_range_check('train_percent_check')
-
+            # try getting the length
             self.num_training_batches = len(self.train_dataloader)
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
 
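
A worked example of the batch-count arithmetic above, with made-up numbers: a finite loader of 500 batches and `train_percent_check=0.25` trains on 125 batches per epoch, while an infinite loader keeps `num_training_batches = float('inf')`.

num_batches = 500                   # len(train_dataloader) for a finite loader
train_percent_check = 0.25          # fraction of the epoch to actually run
num_training_batches = int(num_batches * train_percent_check)
assert num_training_batches == 125  # an infinite loader would stay at float('inf')
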
@@ -168,27 +143,26 @@ def reset_train_dataloader(self, model):
                     f"to the number of the training batches ({self.num_training_batches}). "
                     f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
         else:
+            if self.is_infinite_dataloader(self.train_dataloader):
+                m = '''
+                When using an infinite DataLoader (e.g. with an IterableDataset or when DataLoader
+                does not implement `__len__`) for `train_dataloader`, `Trainer(val_check_interval)`
+                must be an int. An int k specifies checking validation every k training batches.
+                '''
+                raise MisconfigurationException(m)
+
             self._percent_range_check('val_check_interval')
 
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
 
-        # support IterableDataset for train data
-        self.is_iterable_train_dataloader = (
-            EXIST_ITER_DATASET and isinstance(self.train_dataloader.dataset, IterableDataset)
-        )
-        if self.is_iterable_dataloader(self.train_dataloader) and not isinstance(self.val_check_interval, int):
-            m = '''
-            When using an iterableDataset for `train_dataloader`,
-            `Trainer(val_check_interval)` must be an int.
-            An int k specifies checking validation every k training batches
-            '''
-            raise MisconfigurationException(m)
-
-    def is_iterable_dataloader(self, dataloader):
-        return (
-            EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset)
-        )
+    def is_infinite_dataloader(self, dataloader):
+        try:
+            # try getting the length
+            _ = len(dataloader)
+            return False
+        except TypeError as e:
+            return True
 
     def reset_val_dataloader(self, model):
         """