Skip to content

Commit f2191b0

Browse files
BordawilliamFalcon
authored andcommitted
fix for pyTorch 1.2 (#549)
* min pytorch 1.2 * fix IterableDataset * upgrade torchvision * fix msg
1 parent 55f3ffd commit f2191b0

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

pytorch_lightning/trainer/data_loading_mixin.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
import warnings
22

33
import torch.distributed as dist
4-
from torch.utils.data import IterableDataset
4+
try:
5+
# loading for pyTorch 1.3
6+
from torch.utils.data import IterableDataset
7+
except ImportError:
8+
# loading for pyTorch 1.1
9+
import torch
10+
warnings.warn('Your version of pyTorch %s does not support `IterableDataset`,'
11+
' please upgrade to 1.2+' % torch.__version__, ImportWarning)
12+
EXIST_ITER_DATASET = False
13+
else:
14+
EXIST_ITER_DATASET = True
515
from torch.utils.data.distributed import DistributedSampler
616

717
from pytorch_lightning.utilities.debugging import MisconfigurationException
@@ -24,7 +34,7 @@ def init_train_dataloader(self, model):
2434
self.get_train_dataloader = model.train_dataloader
2535

2636
# determine number of training batches
27-
if isinstance(self.get_train_dataloader().dataset, IterableDataset):
37+
if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
2838
self.nb_training_batches = float('inf')
2939
else:
3040
self.nb_training_batches = len(self.get_train_dataloader())
@@ -167,7 +177,8 @@ def get_dataloaders(self, model):
167177
self.get_val_dataloaders()
168178

169179
# support IterableDataset for train data
170-
self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset)
180+
self.is_iterable_train_dataloader = (
181+
EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
171182
if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int):
172183
m = '''
173184
When using an iterableDataset for train_dataloader,

requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
scikit-learn>=0.20.2
22
tqdm>=4.35.0
33
numpy>=1.16.4
4-
torch>=1.1
5-
torchvision>=0.3.0
4+
torch>=1.2
5+
torchvision>=0.4.0
66
pandas>=0.24 # lower version do not support py3.7
77
test-tube>=0.6.9
88
# future>=0.17.1 # required for buildins in setup.py

0 commit comments

Comments
 (0)