import warnings

import torch.distributed as dist
-from torch.utils.data import IterableDataset
+try:
+    # IterableDataset is only available in PyTorch 1.2+
+    from torch.utils.data import IterableDataset
+except ImportError:
+    # fall back for PyTorch 1.1 and earlier
+    import torch
+    warnings.warn('Your version of PyTorch %s does not support `IterableDataset`,'
+                  ' please upgrade to 1.2+' % torch.__version__, ImportWarning)
+    EXIST_ITER_DATASET = False
+else:
+    EXIST_ITER_DATASET = True
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.utilities.debugging import MisconfigurationException
@@ -24,7 +34,7 @@ def init_train_dataloader(self, model):
        self.get_train_dataloader = model.train_dataloader

        # determine number of training batches
-        if isinstance(self.get_train_dataloader().dataset, IterableDataset):
+        if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
            self.nb_training_batches = float('inf')
        else:
            self.nb_training_batches = len(self.get_train_dataloader())
@@ -167,7 +177,8 @@ def get_dataloaders(self, model):
        self.get_val_dataloaders()

        # support IterableDataset for train data
-        self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset)
+        self.is_iterable_train_dataloader = (
+            EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
        if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
            m = '''
            When using an IterableDataset for train_dataloader,
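The pattern at the heart of this patch is a guarded import: the try/except/else sets a module-level EXIST_ITER_DATASET flag once at import time, and every later reference to IterableDataset is placed behind that flag. Below is a minimal, self-contained sketch of the same pattern; infer_num_batches is a hypothetical helper (not part of pytorch_lightning) that mirrors the patched init_train_dataloader logic.

import warnings

import torch

try:
    # IterableDataset was added to torch.utils.data in PyTorch 1.2
    from torch.utils.data import IterableDataset
except ImportError:
    warnings.warn('Your version of PyTorch %s does not support `IterableDataset`,'
                  ' please upgrade to 1.2+' % torch.__version__, ImportWarning)
    EXIST_ITER_DATASET = False
else:
    # the else branch runs only when the import succeeded
    EXIST_ITER_DATASET = True


def infer_num_batches(dataloader):
    """Hypothetical helper: an IterableDataset has no len(),
    so the number of batches is reported as infinite."""
    # `and` short-circuits, so the name IterableDataset is never
    # evaluated (and cannot raise NameError) when the import failed
    if EXIST_ITER_DATASET and isinstance(dataloader.dataset, IterableDataset):
        return float('inf')
    return len(dataloader)

The operand order in `EXIST_ITER_DATASET and isinstance(...)` matters: because `and` short-circuits, the isinstance check is only reached on PyTorch 1.2+, which is what keeps the patched trainer importable on PyTorch 1.1.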