
Commit 1e68968

awaelchli, Borda, and williamFalcon authored
support num_sanity_val_steps=-1 (#2246)
* support sanity_val_step=-1
* fix list size
* simplification
* simplify
* add test for num_sanity_val_steps=-1
* update test
* update docs
* extend tests to multiple dataloaders
* changelog
* Update tests/trainer/test_trainer.py
  Co-authored-by: Jirka Borovec <[email protected]>
* improve test
* refactor the sanity check decision
* fix merge
* Update trainer.py

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 62ce00f commit 1e68968

File tree

8 files changed: +62 -17 lines

CHANGELOG.md (+2)

@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
 
+- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))
+
 ### Changed
 

docs/source/debugging.rst (+1 -1)

@@ -129,4 +129,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 .. testcode::
 
     # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)

pytorch_lightning/callbacks/progress.py (+2 -2)

@@ -102,7 +102,7 @@ def total_val_batches(self) -> int:
         total_val_batches = 0
         if trainer.fast_dev_run and trainer.val_dataloaders is not None:
             total_val_batches = len(trainer.val_dataloaders)
-        elif not self.trainer.disable_validation:
+        elif self.trainer.enable_validation:
             is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
             total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
         return total_val_batches
@@ -302,7 +302,7 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = trainer.num_sanity_val_steps * len(trainer.val_dataloaders)
+        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, trainer, pl_module):
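Note: the progress bar total now goes through `convert_inf(...)` because `num_sanity_val_steps` can be `float('inf')` after this change, and tqdm expects `total=None` rather than an infinite number for an unknown length. The helper itself is not part of this diff; a minimal sketch of what it is assumed to do:

from typing import Optional, Union

def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
    # tqdm treats total=None as "unknown length", so map an infinite total to None
    return None if x == float('inf') else x

assert convert_inf(10) == 10
assert convert_inf(float('inf')) is None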

pytorch_lightning/trainer/__init__.py (+5 -2)

@@ -603,16 +603,19 @@ def on_train_end(self, trainer, pl_module):
 
 Sanity check runs n batches of val before starting the training routine.
 This catches any bugs in your validation without having to wait for the first validation check.
-The Trainer uses 5 steps by default. Turn it off or modify it here.
+The Trainer uses 2 steps by default. Turn it off or modify it here.
 
 .. testcode::
 
     # default used by the Trainer
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)
 
     # turn it off
     trainer = Trainer(num_sanity_val_steps=0)
 
+    # check all validation data
+    trainer = Trainer(num_sanity_val_steps=-1)
+
 num_tpu_cores
 ^^^^^^^^^^^^^
 .. warning:: .. deprecated:: 0.7.6

pytorch_lightning/trainer/evaluation_loop.py (+4 -3)

@@ -72,15 +72,16 @@
 -----------------------------------------
 
 Lightning runs a few steps of validation in the beginning of training.
-This avoids crashing in the validation loop sometime deep into a lengthy training loop.
+This avoids crashing in the validation loop sometime deep into a lengthy training loop.
 
 .. code-block:: python
 
     # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
+    trainer = Trainer(num_sanity_val_steps=2)
 
-You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check.
+You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
+to check all the validation data.
 
 # Testing loop
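For reference, a minimal usage sketch of the documented options; `MyModel` is a hypothetical LightningModule that defines `validation_step` and `val_dataloader`:

from pytorch_lightning import Trainer

model = MyModel()  # hypothetical LightningModule with validation_step and val_dataloader defined

# default: run 2 validation batches before training
trainer = Trainer(num_sanity_val_steps=2)

# skip the sanity check entirely
trainer = Trainer(num_sanity_val_steps=0)

# run every batch of every validation dataloader before training starts
trainer = Trainer(num_sanity_val_steps=-1)
trainer.fit(model)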

pytorch_lightning/trainer/trainer.py (+18 -8)

@@ -336,7 +336,8 @@ def __init__(
 
             amp_level: The optimization level to use (O1, O2, etc...).
 
-            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
+            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
+                Set it to `-1` to run all batches in all validation dataloaders. Default: 2
 
             truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
@@ -408,7 +409,6 @@ def __init__(
         # training state
         self.model = None
         self.testing = False
-        self.disable_validation = False
         self.prepare_data_per_node = prepare_data_per_node
         self.lr_schedulers = []
         self.optimizers = None
@@ -488,7 +488,7 @@ def __init__(
         self.max_steps = max_steps
         self.min_steps = min_steps
 
-        self.num_sanity_val_steps = num_sanity_val_steps
+        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
         # Backward compatibility, TODO: remove in v0.9.0
         if print_nan_grads:
             rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
@@ -883,6 +883,17 @@ def progress_bar_dict(self) -> dict:
         ref_model = self.model if not self.data_parallel else self.model.module
         return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)
 
+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
+        return val_loop_enabled or self.fast_dev_run
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
@@ -1186,10 +1197,6 @@ def run_pretrain_routine(self, model: LightningModule):
 
             return eval_loop_results
 
-        # check if we should run validation during training
-        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
-            and not self.fast_dev_run
-
         # run a few val batches before training starts
         self._run_sanity_check(ref_model, model)
 
@@ -1204,9 +1211,12 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()
 
     def _run_sanity_check(self, ref_model, model):
+        should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
+            and self.limit_val_batches > 0
+
         # run tiny validation (if validation defined)
        # to make sure program won't crash during val
-        if not self.disable_validation and self.num_sanity_val_steps > 0:
+        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
 
            # hook and callback
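Taken together, the new decision logic can be summarized by this standalone sketch. The function names are illustrative (they are not Trainer methods); only the expressions mirror the code above:

def resolve_num_sanity_val_steps(num_sanity_val_steps):
    # -1 is mapped to infinity so the sanity loop runs until the val dataloaders are exhausted
    return float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps

def should_run_sanity_check(has_validation_step, num_sanity_val_steps, limit_val_batches):
    # mirrors `_run_sanity_check`: needs a validation_step, a positive step count,
    # and a non-zero fraction/number of validation batches
    return has_validation_step and num_sanity_val_steps > 0 and limit_val_batches > 0

assert resolve_num_sanity_val_steps(-1) == float("inf")
assert should_run_sanity_check(True, resolve_num_sanity_val_steps(-1), 1.0)
assert not should_run_sanity_check(True, resolve_num_sanity_val_steps(0), 1.0)
assert not should_run_sanity_check(True, resolve_num_sanity_val_steps(2), 0.0)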

pytorch_lightning/trainer/training_loop.py (+1 -1)

@@ -651,7 +651,7 @@ def should_check_val(self, batch_idx, is_last_batch):
         # decide if we should run validation
         is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0
         can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
-        can_check_val = not self.disable_validation and can_check_epoch
+        can_check_val = self.enable_validation and can_check_epoch
         should_check_val = is_val_check_batch or self.should_stop
         is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf'))
         should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset)

tests/trainer/test_trainer.py (+29)

@@ -6,6 +6,7 @@
 import types
 from argparse import Namespace
 from pathlib import Path
+from unittest.mock import patch
 
 import cloudpickle
 import pytest
@@ -807,6 +808,34 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     assert trainer.tpu_id == expected_tpu_id
 
 
+@pytest.mark.parametrize(['limit_val_batches'], [
+    pytest.param(0.0),  # this should run no sanity checks
+    pytest.param(1),
+    pytest.param(1.0),
+    pytest.param(0.3),
+])
+def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+    """
+    Test that num_sanity_val_steps=-1 runs through all validation data once.
+    Makes sure this setting is independent of limit_val_batches.
+    """
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=-1,
+        limit_val_batches=limit_val_batches,  # should have no influence
+        max_steps=1,
+    )
+    assert trainer.num_sanity_val_steps == float('inf')
+    val_dataloaders = model.val_dataloader__multiple()
+
+    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+        assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+
+
 @pytest.mark.parametrize("trainer_kwargs,expected", [
     pytest.param(
         dict(distributed_backend=None, gpus=None),
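The call-counting trick in the test is plain `unittest.mock`: passing `wraps=` keeps the original method running while the mock records every call. A self-contained sketch of the same pattern, unrelated to Lightning:

from unittest.mock import patch

class Runner:
    def step(self, batch):
        return batch * 2

runner = Runner()
with patch.object(runner, 'step', wraps=runner.step) as mocked:
    results = [runner.step(b) for b in range(3)]

assert results == [0, 2, 4]    # the real method still executed
assert mocked.call_count == 3  # and every call was recorded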
