diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e639a5c53f6f..5390509b8c3db 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -138,6 +138,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed checkpointing to remote file paths ([#2925](https://github.com/PyTorchLightning/pytorch-lightning/pull/2925))
 
+- Fixed the total steps of the progress bar for the validation sanity check ([#2892](https://github.com/PyTorchLightning/pytorch-lightning/pull/2892))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index c3cd9137c9ed7..0475e2ec2cd1e 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -17,6 +17,7 @@
 from tqdm import tqdm
 
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.data import has_len
 
 
 class ProgressBarBase(Callback):
@@ -293,7 +294,9 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+        self.val_progress_bar.total = sum(
+            min(trainer.num_sanity_val_steps, len(d) if has_len(d) else float('inf')) for d in trainer.val_dataloaders
+        )
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, trainer, pl_module):
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 575e28354b5de..8f3bd54d8a888 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -1,24 +1,17 @@
 import multiprocessing
 import platform
 from abc import ABC, abstractmethod
-from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
 
-import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-try:
-    from torch.utils.data import IterableDataset
-    ITERABLE_DATASET_EXISTS = True
-except ImportError:
-    ITERABLE_DATASET_EXISTS = False
-
 try:
     from apex import amp
 except ImportError:
@@ -41,35 +34,6 @@
     HOROVOD_AVAILABLE = True
 
 
-def _has_iterable_dataset(dataloader: DataLoader):
-    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
-        and isinstance(dataloader.dataset, IterableDataset)
-
-
-def _has_len(dataloader: DataLoader) -> bool:
-    """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader. """
-
-    try:
-        # try getting the length
-        if len(dataloader) == 0:
-            raise ValueError('`Dataloader` returned 0 length.'
-                             ' Please make sure that your Dataloader at least returns 1 batch')
-        has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        has_len = False
-
-    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
-        rank_zero_warn(
-            'Your `IterableDataset` has `__len__` defined.'
-            ' In combination with multi-processing data loading (e.g. batch size > 1),'
-            ' this can lead to unintended side effects since the samples will be duplicated.'
-        )
-    return has_len
-
-
 class TrainerDataLoadingMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
@@ -131,7 +95,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
-        is_iterable_ds = _has_iterable_dataset(dataloader)
+        is_iterable_ds = has_iterable_dataset(dataloader)
         if not is_dataloader or is_iterable_ds:
             return dataloader
 
@@ -195,7 +159,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
-        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
+        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
 
         if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -219,7 +183,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     f'to the number of the training batches ({self.num_training_batches}). '
                     'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
         else:
-            if not _has_len(self.train_dataloader):
+            if not has_len(self.train_dataloader):
                 if self.val_check_interval == 1.0:
                     self.val_check_batch = float('inf')
                 else:
@@ -282,7 +246,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
+                num_batches = len(dataloader) if has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
new file mode 100644
index 0000000000000..7569d68f00a9f
--- /dev/null
+++ b/pytorch_lightning/utilities/data.py
@@ -0,0 +1,41 @@
+from distutils.version import LooseVersion
+
+import torch
+from torch.utils.data import DataLoader
+
+from pytorch_lightning.utilities import rank_zero_warn
+
+try:
+    from torch.utils.data import IterableDataset
+    ITERABLE_DATASET_EXISTS = True
+except ImportError:
+    ITERABLE_DATASET_EXISTS = False
+
+
+def has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
+def has_len(dataloader: DataLoader) -> bool:
+    """ Checks if a given Dataloader has __len__ method implemented i.e. if
+    it is a finite dataloader or infinite dataloader. """
+
+    try:
+        # try getting the length
+        if len(dataloader) == 0:
+            raise ValueError('`Dataloader` returned 0 length.'
+                             ' Please make sure that your Dataloader at least returns 1 batch')
+        has_len = True
+    except TypeError:
+        has_len = False
+    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+        has_len = False
+
+    if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 713bdf3c3c2c4..023b5927699eb 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -193,3 +193,45 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx
 
     trainer.test(model)
     assert progress_bar.test_batches_seen == progress_bar.total_test_batches
+
+
+@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,expected_num_steps', [
+    (-1, [10], 10),
+    (0, [10], 0),
+    (2, [10], 2),
+    (10, [2], 2),
+    (10, [2, 3], 5),
+    (10, [20, 3], 13),
+    (10, [20, 30], 20),
+    (10, [float('inf')], 10),
+    (10, [1, float('inf')], 11),
+])
+def test_sanity_check_progress_bar_total(
+    tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps
+):
+    """Test that the sanity_check progress finishes with the correct total steps processed."""
+
+    tmp_model = EvalModelTemplate(batch_size=1)
+    batch_size = len(tmp_model.dataloader(train=False, num_samples=1).dataset)
+    model = EvalModelTemplate(batch_size=batch_size)
+
+    num_dataloaders = len(num_val_dataloaders_batches)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=1,
+        limit_val_batches=len(model.dataloader(train=False)) * num_dataloaders,
+        max_epochs=0,
+        num_sanity_val_steps=num_sanity_val_steps,
+    )
+
+    val_dataloaders = []
+    for num_samples in num_val_dataloaders_batches:
+        if num_samples == float('inf'):
+            val_dataloaders.append(model.val_dataloader__infinite())
+        else:
+            val_dataloaders.append(
+                model.dataloader(train=False, num_samples=num_samples))
+    trainer.fit(model, val_dataloaders=val_dataloaders)
+
+    val_progress_bar = trainer.progress_bar_callback.val_progress_bar
+    assert getattr(val_progress_bar, 'total', 0) == expected_num_steps
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index d9e2500707fc8..52abb16478bca 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -4,14 +4,12 @@
 
 import pytest
 import torch
-from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.dataset import Subset
 from torch.utils.data.distributed import DistributedSampler
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer, Callback
-from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -619,36 +617,6 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
 
     trainer.test(**test_options)
 
 
-@pytest.mark.xfail(
-    parse(torch.__version__) < parse("1.4.0"),
-    reason="IterableDataset with __len__ before 1.4 raises",
-)
-def test_warning_with_iterable_dataset_and_len(tmpdir):
-    """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
-    model = EvalModelTemplate()
-    original_dataset = model.train_dataloader().dataset
-
-    class IterableWithLen(IterableDataset):
-
-        def __iter__(self):
-            return iter(original_dataset)
-
-        def __len__(self):
-            return len(original_dataset)
-
-    dataloader = DataLoader(IterableWithLen(), batch_size=16)
-    assert _has_len(dataloader)
-    assert _has_iterable_dataset(dataloader)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=3,
-    )
-    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
-        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
-    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
-        trainer.test(model, test_dataloaders=[dataloader])
-
-
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
 def test_dataloader_reinit_for_subclass(tmpdir):
diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py
new file mode 100644
index 0000000000000..b0c31b68a561b
--- /dev/null
+++ b/tests/utilities/test_data.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+from packaging.version import parse
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import IterableDataset
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
+from tests.base import EvalModelTemplate
+
+
+@pytest.mark.xfail(
+    parse(torch.__version__) < parse("1.4.0"),
+    reason="IterableDataset with __len__ before 1.4 raises",
+)
+def test_warning_with_iterable_dataset_and_len(tmpdir):
+    """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
+    model = EvalModelTemplate()
+    original_dataset = model.train_dataloader().dataset
+
+    class IterableWithLen(IterableDataset):
+
+        def __iter__(self):
+            return iter(original_dataset)
+
+        def __len__(self):
+            return len(original_dataset)
+
+    dataloader = DataLoader(IterableWithLen(), batch_size=16)
+    assert has_len(dataloader)
+    assert has_iterable_dataset(dataloader)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=3,
+    )
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.test(model, test_dataloaders=[dataloader])