From 0e591aab0845511f2bbc2d070b0e1d5935a74f30 Mon Sep 17 00:00:00 2001 From: manipopopo Date: Sun, 9 Aug 2020 09:57:43 +0000 Subject: [PATCH 1/9] Fix the progress bar for the sanity check The original progress bar will always show trainer.num_sanity_val_steps even if the length of the validation DataLoader is less than trainer.num_sanity_val_steps. The pytorch_lightning.trainer.data_loading._has_len is changed to a public function has_len, which is called by pytorch_lightning/callbacks/progress.py --- pytorch_lightning/callbacks/progress.py | 10 ++++- pytorch_lightning/trainer/data_loading.py | 8 ++-- tests/callbacks/test_progress_bar.py | 47 +++++++++++++++++++++++ tests/trainer/test_dataloaders.py | 4 +- 4 files changed, 62 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index c3cd9137c9ed7..ffdf90310f534 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -293,7 +293,15 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders)) + + # There will be circular imports if we move the following import to the + # top of the file. + from pytorch_lightning.trainer.data_loading import has_len + + self.val_progress_bar.total = sum( + min(trainer.num_sanity_val_steps, + len(d) if has_len(d) else float('inf')) + for d in trainer.val_dataloaders) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 575e28354b5de..a4f239f379ed0 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -46,7 +46,7 @@ def _has_iterable_dataset(dataloader: DataLoader): and isinstance(dataloader.dataset, IterableDataset) -def _has_len(dataloader: DataLoader) -> bool: +def has_len(dataloader: DataLoader) -> bool: """ Checks if a given Dataloader has __len__ method implemented i.e. if it is a finite dataloader or infinite dataloader. """ @@ -195,7 +195,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: # automatically add samplers self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True) - self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf') + self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf') self._worker_check(self.train_dataloader, 'train dataloader') if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0: @@ -219,7 +219,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None: f'to the number of the training batches ({self.num_training_batches}). 
' 'If you want to disable validation set `limit_val_batches` to 0.0 instead.') else: - if not _has_len(self.train_dataloader): + if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: self.val_check_batch = float('inf') else: @@ -282,7 +282,7 @@ def _reset_eval_dataloader( # datasets could be none, 1 or 2+ if len(dataloaders) != 0: for i, dataloader in enumerate(dataloaders): - num_batches = len(dataloader) if _has_len(dataloader) else float('inf') + num_batches = len(dataloader) if has_len(dataloader) else float('inf') self._worker_check(dataloader, f'{mode} dataloader {i}') # percent or num_steps diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 713bdf3c3c2c4..276810d7c29ed 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -193,3 +193,50 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx trainer.test(model) assert progress_bar.test_batches_seen == progress_bar.total_test_batches + + +@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches', [ + (-1, [10]), + (0, [10]), + (2, [10]), + (10, [2]), + (10, [2, 3]), + (10, [20, 3]), + (10, [20, 30]), + (10, [float('inf')]), + (10, [1, float('inf')]), +]) +def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, + num_val_dataloaders_batches): + """Test that the sanity_check progress finishes with the correct total steps processed.""" + + tmp_model = EvalModelTemplate(batch_size=1) + batch_size = len(tmp_model.dataloader(train=False, num_samples=1).dataset) + model = EvalModelTemplate(batch_size=batch_size) + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=1, + limit_val_batches=(len(model.dataloader(train=False)) * + len(num_val_dataloaders_batches)), + max_epochs=0, + num_sanity_val_steps=num_sanity_val_steps, + ) + + val_dataloaders = [] + for num_samples in num_val_dataloaders_batches: + if num_samples == float('inf'): + val_dataloaders.append(model.val_dataloader__infinite()) + else: + val_dataloaders.append( + model.dataloader(train=False, num_samples=num_samples)) + trainer.fit(model, val_dataloaders=val_dataloaders) + + # check val progress bar total is the number of steps sanity check runs + max_sanity_val_steps = (float('inf') if num_sanity_val_steps == -1 else + num_sanity_val_steps) + num_sanity_check_run_steps = sum( + min(max_sanity_val_steps, num_val_dataloader_batches) + for num_val_dataloader_batches in num_val_dataloaders_batches) + val_progress_bar = trainer.progress_bar_callback.val_progress_bar + assert getattr(val_progress_bar, 'total', 0) == num_sanity_check_run_steps diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index d9e2500707fc8..0d58aea53698d 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -11,7 +11,7 @@ import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer, Callback -from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len +from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -637,7 +637,7 @@ def __len__(self): return len(original_dataset) dataloader = DataLoader(IterableWithLen(), batch_size=16) - assert _has_len(dataloader) + assert has_len(dataloader) assert _has_iterable_dataset(dataloader) trainer = Trainer( 
default_root_dir=tmpdir, From 0a9342a5add1297208bd555f11c3a1c869e23564 Mon Sep 17 00:00:00 2001 From: manipopopo Date: Sun, 9 Aug 2020 10:25:09 +0000 Subject: [PATCH 2/9] Fix W504 line break after binary operator --- tests/callbacks/test_progress_bar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 276810d7c29ed..9804be9f00450 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -214,11 +214,11 @@ def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, batch_size = len(tmp_model.dataloader(train=False, num_samples=1).dataset) model = EvalModelTemplate(batch_size=batch_size) + num_dataloaders = len(num_val_dataloaders_batches) trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=1, - limit_val_batches=(len(model.dataloader(train=False)) * - len(num_val_dataloaders_batches)), + limit_val_batches=len(model.dataloader(train=False)) * num_dataloaders, max_epochs=0, num_sanity_val_steps=num_sanity_val_steps, ) From f87074f07a06222e900e8440cdf21cf450c99b2a Mon Sep 17 00:00:00 2001 From: manipopopo Date: Tue, 11 Aug 2020 13:01:49 +0000 Subject: [PATCH 3/9] Move functions to pytorch_lightning.utilities.data --- pytorch_lightning/callbacks/progress.py | 11 ++---- pytorch_lightning/trainer/data_loading.py | 40 ++-------------------- pytorch_lightning/utilities/data.py | 41 +++++++++++++++++++++++ tests/trainer/test_dataloaders.py | 34 +------------------ tests/utilities/test_data.py | 39 +++++++++++++++++++++ 5 files changed, 86 insertions(+), 79 deletions(-) create mode 100644 pytorch_lightning/utilities/data.py create mode 100644 tests/utilities/test_data.py diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index ffdf90310f534..0475e2ec2cd1e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -17,6 +17,7 @@ from tqdm import tqdm from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.data import has_len class ProgressBarBase(Callback): @@ -293,15 +294,9 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - - # There will be circular imports if we move the following import to the - # top of the file. 
- from pytorch_lightning.trainer.data_loading import has_len - self.val_progress_bar.total = sum( - min(trainer.num_sanity_val_steps, - len(d) if has_len(d) else float('inf')) - for d in trainer.val_dataloaders) + min(trainer.num_sanity_val_steps, len(d) if has_len(d) else float('inf')) for d in trainer.val_dataloaders + ) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index a4f239f379ed0..8f3bd54d8a888 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -1,24 +1,17 @@ import multiprocessing import platform from abc import ABC, abstractmethod -from distutils.version import LooseVersion from typing import Union, List, Tuple, Callable, Optional -import torch import torch.distributed as torch_distrib from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException -try: - from torch.utils.data import IterableDataset - ITERABLE_DATASET_EXISTS = True -except ImportError: - ITERABLE_DATASET_EXISTS = False - try: from apex import amp except ImportError: @@ -41,35 +34,6 @@ HOROVOD_AVAILABLE = True -def _has_iterable_dataset(dataloader: DataLoader): - return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \ - and isinstance(dataloader.dataset, IterableDataset) - - -def has_len(dataloader: DataLoader) -> bool: - """ Checks if a given Dataloader has __len__ method implemented i.e. if - it is a finite dataloader or infinite dataloader. """ - - try: - # try getting the length - if len(dataloader) == 0: - raise ValueError('`Dataloader` returned 0 length.' - ' Please make sure that your Dataloader at least returns 1 batch') - has_len = True - except TypeError: - has_len = False - except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used - has_len = False - - if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): - rank_zero_warn( - 'Your `IterableDataset` has `__len__` defined.' - ' In combination with multi-processing data loading (e.g. batch size > 1),' - ' this can lead to unintended side effects since the samples will be duplicated.' 
- ) - return has_len - - class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, @@ -131,7 +95,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader is_dataloader = isinstance(dataloader, DataLoader) # don't manipulate iterable datasets - is_iterable_ds = _has_iterable_dataset(dataloader) + is_iterable_ds = has_iterable_dataset(dataloader) if not is_dataloader or is_iterable_ds: return dataloader diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py new file mode 100644 index 0000000000000..7569d68f00a9f --- /dev/null +++ b/pytorch_lightning/utilities/data.py @@ -0,0 +1,41 @@ +from distutils.version import LooseVersion + +import torch +from torch.utils.data import DataLoader + +from pytorch_lightning.utilities import rank_zero_warn + +try: + from torch.utils.data import IterableDataset + ITERABLE_DATASET_EXISTS = True +except ImportError: + ITERABLE_DATASET_EXISTS = False + + +def has_iterable_dataset(dataloader: DataLoader): + return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \ + and isinstance(dataloader.dataset, IterableDataset) + + +def has_len(dataloader: DataLoader) -> bool: + """ Checks if a given Dataloader has __len__ method implemented i.e. if + it is a finite dataloader or infinite dataloader. """ + + try: + # try getting the length + if len(dataloader) == 0: + raise ValueError('`Dataloader` returned 0 length.' + ' Please make sure that your Dataloader at least returns 1 batch') + has_len = True + except TypeError: + has_len = False + except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used + has_len = False + + if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + rank_zero_warn( + 'Your `IterableDataset` has `__len__` defined.' + ' In combination with multi-processing data loading (e.g. batch size > 1),' + ' this can lead to unintended side effects since the samples will be duplicated.' + ) + return has_len diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 0d58aea53698d..52abb16478bca 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -4,14 +4,12 @@ import pytest import torch -from packaging.version import parse from torch.utils.data.dataloader import DataLoader -from torch.utils.data.dataset import IterableDataset, Subset +from torch.utils.data.dataset import Subset from torch.utils.data.distributed import DistributedSampler import tests.base.develop_pipelines as tpipes from pytorch_lightning import Trainer, Callback -from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, has_len from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -619,36 +617,6 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path): trainer.test(**test_options) -@pytest.mark.xfail( - parse(torch.__version__) < parse("1.4.0"), - reason="IterableDataset with __len__ before 1.4 raises", -) -def test_warning_with_iterable_dataset_and_len(tmpdir): - """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. 
""" - model = EvalModelTemplate() - original_dataset = model.train_dataloader().dataset - - class IterableWithLen(IterableDataset): - - def __iter__(self): - return iter(original_dataset) - - def __len__(self): - return len(original_dataset) - - dataloader = DataLoader(IterableWithLen(), batch_size=16) - assert has_len(dataloader) - assert _has_iterable_dataset(dataloader) - trainer = Trainer( - default_root_dir=tmpdir, - max_steps=3, - ) - with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): - trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) - with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): - trainer.test(model, test_dataloaders=[dataloader]) - - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs') def test_dataloader_reinit_for_subclass(tmpdir): diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py new file mode 100644 index 0000000000000..31b5ee5ceb566 --- /dev/null +++ b/tests/utilities/test_data.py @@ -0,0 +1,39 @@ +import pytest +import torch +from packaging.version import parse +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import IterableDataset + +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.data import has_iterable_dataset, has_len +from tests.base import EvalModelTemplate + + +@pytest.mark.xfail( + parse(torch.__version__) < parse("1.4.0"), + reason="IterableDataset with __len__ before 1.4 raises", +) +def test_warning_with_iterable_dataset_and_len(tmpdir): + """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """ + model = EvalModelTemplate() + original_dataset = model.train_dataloader().dataset + + class IterableWithLen(IterableDataset): + + def __iter__(self): + return iter(original_dataset) + + def __len__(self): + return len(original_dataset) + + dataloader = DataLoader(IterableWithLen(), batch_size=16) + assert has_len(dataloader) + assert has_iterable_dataset(dataloader) + trainer = Trainer( + default_root_dir=tmpdir, + max_steps=3, + ) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader]) + with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'): + trainer.test(model, test_dataloaders=[dataloader]) From 8b1fde44a08c88645a787df874a7b02c60185e78 Mon Sep 17 00:00:00 2001 From: manipopopo Date: Tue, 11 Aug 2020 13:48:20 +0000 Subject: [PATCH 4/9] Simplify test cases --- tests/callbacks/test_progress_bar.py | 29 +++++++++++----------------- 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 9804be9f00450..96745e08ad8b1 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -195,19 +195,18 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx assert progress_bar.test_batches_seen == progress_bar.total_test_batches -@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches', [ - (-1, [10]), - (0, [10]), - (2, [10]), - (10, [2]), - (10, [2, 3]), - (10, [20, 3]), - (10, [20, 30]), - (10, [float('inf')]), - (10, [1, float('inf')]), +@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,num_sanity_check_run_steps', [ + (-1, [10], 10), + (0, [10], 0), + (2, [10], 2), + (10, [2], 2), + (10, 
[2, 3], 5), + (10, [20, 3], 13), + (10, [20, 30], 20), + (10, [float('inf')], 10), + (10, [1, float('inf')], 11), ]) -def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, - num_val_dataloaders_batches): +def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, num_sanity_check_run_steps): """Test that the sanity_check progress finishes with the correct total steps processed.""" tmp_model = EvalModelTemplate(batch_size=1) @@ -232,11 +231,5 @@ def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, model.dataloader(train=False, num_samples=num_samples)) trainer.fit(model, val_dataloaders=val_dataloaders) - # check val progress bar total is the number of steps sanity check runs - max_sanity_val_steps = (float('inf') if num_sanity_val_steps == -1 else - num_sanity_val_steps) - num_sanity_check_run_steps = sum( - min(max_sanity_val_steps, num_val_dataloader_batches) - for num_val_dataloader_batches in num_val_dataloaders_batches) val_progress_bar = trainer.progress_bar_callback.val_progress_bar assert getattr(val_progress_bar, 'total', 0) == num_sanity_check_run_steps From a6be8090b585f8abb114d4b3a624f59d6b4708c9 Mon Sep 17 00:00:00 2001 From: manipopopo Date: Tue, 11 Aug 2020 13:48:35 +0000 Subject: [PATCH 5/9] Update CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e639a5c53f6f..5390509b8c3db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,6 +138,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed checkpointing to remote file paths ([#2925](https://github.com/PyTorchLightning/pytorch-lightning/pull/2925)) +- Fixed the total steps of the progress bar for the validation sanity check ([#2892](https://github.com/PyTorchLightning/pytorch-lightning/pull/2892)) + ## [0.8.5] - 2020-07-09 ### Added From 8041aa9b72324062fbcbda33d0a39e0d39a72051 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Aug 2020 16:13:56 +0200 Subject: [PATCH 6/9] rename --- tests/callbacks/test_progress_bar.py | 6 +++--- tests/utilities/test_data.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 96745e08ad8b1..347716a8eeb0e 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -195,7 +195,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx assert progress_bar.test_batches_seen == progress_bar.total_test_batches -@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,num_sanity_check_run_steps', [ +@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,expected_num_steps', [ (-1, [10], 10), (0, [10], 0), (2, [10], 2), @@ -206,7 +206,7 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx (10, [float('inf')], 10), (10, [1, float('inf')], 11), ]) -def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, num_sanity_check_run_steps): +def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps): """Test that the sanity_check progress finishes with the correct total steps processed.""" tmp_model = EvalModelTemplate(batch_size=1) @@ -232,4 +232,4 @@ def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_d 
trainer.fit(model, val_dataloaders=val_dataloaders) val_progress_bar = trainer.progress_bar_callback.val_progress_bar - assert getattr(val_progress_bar, 'total', 0) == num_sanity_check_run_steps + assert getattr(val_progress_bar, 'total', 0) == expected_num_steps diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 31b5ee5ceb566..b0c31b68a561b 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -14,7 +14,7 @@ reason="IterableDataset with __len__ before 1.4 raises", ) def test_warning_with_iterable_dataset_and_len(tmpdir): - """ Tests that a warning messages is shown when an IterableDataset defines `__len__`. """ + """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """ model = EvalModelTemplate() original_dataset = model.train_dataloader().dataset From 117c1fe16f73437895cbac8bf4aebaa8b40b7e67 Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Tue, 11 Aug 2020 11:11:10 -0400 Subject: [PATCH 7/9] removed pep8 issue --- tests/callbacks/test_progress_bar.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 347716a8eeb0e..023b5927699eb 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -206,7 +206,9 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx (10, [float('inf')], 10), (10, [1, float('inf')], 11), ]) -def test_sanity_check_progress_bar_total(tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps): +def test_sanity_check_progress_bar_total( + tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps +): """Test that the sanity_check progress finishes with the correct total steps processed.""" tmp_model = EvalModelTemplate(batch_size=1) From 5987308b7a6f07704a4e38d75c4405d86a99268e Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Thu, 13 Aug 2020 20:15:38 -0400 Subject: [PATCH 8/9] doc fix --- docs/source/debugging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 56c96607bcc4b..df0b241293d23 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -1,7 +1,7 @@ .. testsetup:: * from pytorch_lightning.trainer.trainer import Trainer - + .. _debugging: Debugging From e978a580dd37bc827cbd112af46eafafc388e04c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Fri, 14 Aug 2020 13:35:41 -0400 Subject: [PATCH 9/9] doc --- docs/source/debugging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index df0b241293d23..19c75acae16e1 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -2,7 +2,7 @@ from pytorch_lightning.trainer.trainer import Trainer - .. _debugging: +.. _debugging: Debugging =========
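

Note on the resulting behaviour (illustration only, not part of the patches above): the series comes down to one computation — the sanity-check progress bar total is the sum, over the validation dataloaders, of num_sanity_val_steps capped at the dataloader length whenever that length is known. The sketch below restates that logic outside the Trainer. `has_len` is simplified here (the real utility added in pytorch_lightning/utilities/data.py also rejects zero-length loaders and warns when an IterableDataset defines `__len__`), and `sanity_check_total` is a made-up helper name used only for this example.

    from torch.utils.data import DataLoader


    def has_len(dataloader: DataLoader) -> bool:
        """Simplified check: does the dataloader expose a usable __len__ (i.e. is it finite)?"""
        try:
            return len(dataloader) > 0
        except (TypeError, NotImplementedError):
            # Raised e.g. by loaders wrapping an IterableDataset without __len__,
            # or by torchtext when a batch_size_fn is used.
            return False


    def sanity_check_total(num_sanity_val_steps: int, val_dataloaders) -> float:
        # Cap the per-dataloader step count at the dataloader length when it is known,
        # so the bar no longer over-reports when a loader is shorter than the cap.
        return sum(
            min(num_sanity_val_steps, len(d) if has_len(d) else float('inf'))
            for d in val_dataloaders
        )

For example, num_sanity_val_steps=10 with validation loaders of 2 and 3 batches gives a total of 5, matching the (10, [2, 3], 5) case parametrized in tests/callbacks/test_progress_bar.py.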