From d00b2d57d3495c438f36e62bbd77a7918999c7f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 1 Jul 2020 05:57:42 +0200
Subject: [PATCH 01/11] add warning when checking len

---
 pytorch_lightning/trainer/data_loading.py | 26 +++++++++++++++++------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 06ab7b316e1c2..676b5390f01bf 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -41,19 +41,33 @@
     HOROVOD_AVAILABLE = True
 
 
+def _has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
 def _has_len(dataloader: DataLoader) -> bool:
     """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader """
+    it is a finite dataloader or infinite dataloader. """
+
     try:
         # try getting the length
         if len(dataloader) == 0:
             raise ValueError('`Dataloader` returned 0 length.'
                              ' Please make sure that your Dataloader at least returns 1 batch')
-        return True
+        has_len = True
     except TypeError:
-        return False
+        has_len = False
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        return False
+        has_len = False
+
+    if has_len and _has_iterable_dataset(dataloader):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
 
 
 class TrainerDataLoadingMixin(ABC):
@@ -131,9 +145,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't manipulate iterable datasets
         is_dataloader = isinstance(dataloader, DataLoader)
 
-        is_iterable_ds = False
-        if ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset'):
-            is_iterable_ds = isinstance(dataloader.dataset, IterableDataset)
+        is_iterable_ds = _has_iterable_dataset(dataloader)
 
         if not is_dataloader or is_iterable_ds:
             return dataloader

From 55953ebf8c0ecdec2d29235aeb5946716dfdd75f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Wed, 1 Jul 2020 06:57:28 +0200
Subject: [PATCH 02/11] added test

---
 pytorch_lightning/trainer/data_loading.py |  6 +----
 tests/trainer/test_dataloaders.py         | 28 ++++++++++++++++++++++-
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 676b5390f01bf..ef0103aaae72d 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -297,11 +297,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                try:
-                    num_batches = len(dataloader)
-                except (TypeError, NotImplementedError):
-                    num_batches = float('inf')
-
+                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index b36eca8a2e429..f9a4557bee026 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -3,10 +3,11 @@
 import pytest
 import torch
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import Subset, IterableDataset
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
@@ -487,6 +488,31 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
     trainer.test(**test_options)
 
 
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert _has_len(dataloader)
+    assert _has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass():

From 189941895618811550e7b9088ef63bafe2153dbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Wed, 1 Jul 2020 06:59:54 +0200
Subject: [PATCH 03/11] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0968488111d24..6d0bf1b750900 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added reduce ddp results on eval ([#2434](https://github.com/PyTorchLightning/pytorch-lightning/pull/2434))
 
+- Added a warning when an `IterableDataset` has `__len__` defined ([#2437](https://github.com/PyTorchLightning/pytorch-lightning/pull/2437))
+
 ### Changed

From c88ec3e9b86890eb550a71fb3bd78640cbf172dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3%A4lchli?=
Date: Wed, 1 Jul 2020 07:00:36 +0200
Subject: [PATCH 04/11] pep

---
 tests/trainer/test_dataloaders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index f9a4557bee026..b61268c7c606c 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -492,6 +492,7 @@ def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
""" model = EvalModelTemplate() original_dataset = model.train_dataloader().dataset + class IterableWithLen(IterableDataset): def __iter__(self): From 47d8141c54ef51d933a13385893b650aa2075408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Jul 2020 07:18:16 +0200 Subject: [PATCH 05/11] do not show warning below 1.4 --- pytorch_lightning/trainer/data_loading.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ef0103aaae72d..66d16533af3ae 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,9 +1,11 @@ +import multiprocessing import platform from abc import ABC, abstractmethod from typing import Union, List, Tuple, Callable, Optional -import multiprocessing +import torch import torch.distributed as torch_distrib +from packaging.version import parse from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -61,7 +63,7 @@ def _has_len(dataloader: DataLoader) -> bool: except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used has_len = False - if has_len and _has_iterable_dataset(dataloader): + if has_len and _has_iterable_dataset(dataloader) and parse(torch.__version__) >= parse("1.4.0"): rank_zero_warn( 'Your `IterableDataset` has `__len__` defined.' ' In combination with multi-processing data loading (e.g. batch size > 1),' From a4e44cf33f33242b482c0f9b1d682c61cdf5433a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Jul 2020 07:34:10 +0200 Subject: [PATCH 06/11] try version parse --- requirements/base.txt | 3 ++- tests/trainer/test_dataloaders.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/requirements/base.txt b/requirements/base.txt index 2b26d79033f6f..68ebafee3284d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,4 +6,5 @@ tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement -tqdm>=4.41.0 \ No newline at end of file +tqdm>=4.41.0 +packaging \ No newline at end of file diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index b61268c7c606c..3c1df3636448a 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -2,6 +2,7 @@ import pytest import torch +from packaging.version import parse from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Subset, IterableDataset @@ -488,6 +489,11 @@ def test_warning_with_few_workers(tmpdir, ckpt_path): trainer.test(**test_options) +@pytest.mark.xfail( + parse(torch.__version__) < parse("1.4.0"), + reason="IterableDataset with __len__ before 1.4 raises", + raises=TypeError +) def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. 
""" model = EvalModelTemplate() From 4f4449bcacc073d4cf0f3a68b331e8a47b1b6dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Jul 2020 07:36:56 +0200 Subject: [PATCH 07/11] comments --- pytorch_lightning/trainer/data_loading.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 66d16533af3ae..ea5b611bc3eb4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -144,9 +144,8 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader - # don't manipulate iterable datasets is_dataloader = isinstance(dataloader, DataLoader) - + # don't manipulate iterable datasets is_iterable_ds = _has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: From 2647c45a005a52c23b48192ae06d6e29a5948920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 1 Jul 2020 07:43:40 +0200 Subject: [PATCH 08/11] xfail --- tests/trainer/test_dataloaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 3c1df3636448a..99fe02979013f 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -492,7 +492,6 @@ def test_warning_with_few_workers(tmpdir, ckpt_path): @pytest.mark.xfail( parse(torch.__version__) < parse("1.4.0"), reason="IterableDataset with __len__ before 1.4 raises", - raises=TypeError ) def test_warning_with_iterable_dataset_and_len(tmpdir): """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """ From e29f7409189db5f17056a97f8a10113760437c6f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 1 Jul 2020 06:18:17 -0400 Subject: [PATCH 09/11] Update requirements/base.txt Co-authored-by: Jirka Borovec --- requirements/base.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/base.txt b/requirements/base.txt index 68ebafee3284d..bacc868dada85 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,4 +7,3 @@ future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement tqdm>=4.41.0 -packaging \ No newline at end of file From bdd15fe35b4db6f5a82261c95d9cc7ae3ec7d8e2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 1 Jul 2020 06:18:41 -0400 Subject: [PATCH 10/11] Update pytorch_lightning/trainer/data_loading.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index ea5b611bc3eb4..f166a4461046c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -63,7 +63,7 @@ def _has_len(dataloader: DataLoader) -> bool: except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used has_len = False - if has_len and _has_iterable_dataset(dataloader) and parse(torch.__version__) >= parse("1.4.0"): + if has_len and _has_iterable_dataset(dataloader) and version.parse(torch.__version__) >= version.parse("1.4.0"): rank_zero_warn( 'Your `IterableDataset` has `__len__` defined.' ' In combination with multi-processing data loading (e.g. 
From 3e5f67375d3922d1455f231309069ad3f48a0616 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Wed, 1 Jul 2020 13:32:51 +0200
Subject: [PATCH 11/11] version

---
 pytorch_lightning/trainer/data_loading.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index f166a4461046c..e283166234968 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -1,11 +1,11 @@
 import multiprocessing
 import platform
 from abc import ABC, abstractmethod
+from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
 
 import torch
 import torch.distributed as torch_distrib
-from packaging.version import parse
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -63,7 +63,7 @@ def _has_len(dataloader: DataLoader) -> bool:
     except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
         has_len = False
 
-    if has_len and _has_iterable_dataset(dataloader) and version.parse(torch.__version__) >= version.parse("1.4.0"):
+    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
        rank_zero_warn(
             'Your `IterableDataset` has `__len__` defined.'
             ' In combination with multi-processing data loading (e.g. batch size > 1),'
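
---

Editor's note: for illustration, below is a minimal sketch of how the warning introduced by this series surfaces outside the Trainer. It assumes PyTorch >= 1.4 (where calling `len()` on a `DataLoader` wrapping an `IterableDataset` no longer raises `TypeError`, as the xfail marker above documents) and a pytorch-lightning checkout with these patches applied. The `RangeStream` dataset is hypothetical and not part of the patches; only `_has_len`, `_has_iterable_dataset`, and the warning text come from the series itself.

import warnings

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len


class RangeStream(IterableDataset):
    """Hypothetical iterable dataset that (questionably) also defines __len__."""

    def __iter__(self):
        return iter(range(64))

    def __len__(self):
        # Defining __len__ on an IterableDataset is what gets flagged.
        return 64


loader = DataLoader(RangeStream(), batch_size=16)
assert _has_iterable_dataset(loader)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # _has_len() still returns True here, but it also emits the new
    # UserWarning, because a length reported by an IterableDataset can be
    # misleading (e.g. samples duplicated under multi-process loading).
    assert _has_len(loader)

assert any('`__len__` defined' in str(w.message) for w in caught)

The design choice here is that `_has_len` only warns rather than refusing the length: the trainer keeps using `len(dataloader)` for its `num_batches` bookkeeping (see patch 02, which falls back to `float('inf')` only when no length is available), while the user is alerted that the reported length may not match what the loader actually yields.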