
Commit 7cca385

Authored by rohitgr7, Adrian Wälchli (awaelchli), and Ananya Harsh Jha (ananyahjha93)
Fix num_sanity_val_steps is clipped to limit_val_batches (#2917)
* Fix num_sanity_val_steps according to limit_val_steps
* fix test
* add num_sanity_batches
* pep
* update docstring in test
* add more test
* chlog
* update comments and docstring in test

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Ananya Harsh Jha <[email protected]>
1 parent bcdb750 commit 7cca385
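
In short: the Trainer no longer clips num_sanity_val_steps against limit_val_batches at construction time. It keeps the user-supplied value (or infinity for -1) and clips per validation dataloader once the dataloader sizes are known. A minimal plain-Python sketch of that logic (not the actual Trainer code; the names mirror the attributes touched in the diff below, and the batch counts are made up):

def resolve_num_sanity_val_steps(num_sanity_val_steps):
    # mirrors the new __init__ branch: -1 means "run the whole val set"
    if num_sanity_val_steps == -1:
        return float('inf')
    # no longer min(num_sanity_val_steps, limit_val_batches)
    return num_sanity_val_steps

def resolve_num_sanity_val_batches(num_sanity_val_steps, num_val_batches):
    # mirrors the new per-dataloader clipping in _run_sanity_check
    return [min(num_sanity_val_steps, batches) for batches in num_val_batches]

# two hypothetical val dataloaders contributing 2 and 10 batches
steps = resolve_num_sanity_val_steps(4)
print(resolve_num_sanity_val_batches(steps, [2, 10]))          # [2, 4]
print(resolve_num_sanity_val_batches(float('inf'), [2, 10]))   # [2, 10]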

File tree (4 files changed, +44 / -13 lines):

CHANGELOG.md
pytorch_lightning/callbacks/progress.py
pytorch_lightning/trainer/trainer.py
tests/trainer/test_trainer.py

CHANGELOG.md  (+1 / -1)

@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))
 
 ## [0.9.0] - YYYY-MM-DD
 
@@ -121,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
 - Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))
 
-
 ## [0.8.5] - 2020-07-09
 
 ### Added

pytorch_lightning/callbacks/progress.py  (+1 / -1)

@@ -307,7 +307,7 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+        self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, trainer, pl_module):
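
Note: the sanity-check progress bar total now sums the per-dataloader sanity batch counts instead of assuming every dataloader contributes num_sanity_val_steps batches. A small worked example with made-up counts:

# hypothetical: num_sanity_val_steps=5, two val dataloaders with 2 and 10 batches
num_sanity_val_steps = 5
num_sanity_val_batches = [min(num_sanity_val_steps, b) for b in [2, 10]]   # [2, 5]

old_total = num_sanity_val_steps * len(num_sanity_val_batches)   # 10, over-counts the short dataloader
new_total = sum(num_sanity_val_batches)                          # 7, matches the batches actually run
print(old_total, new_total)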

pytorch_lightning/trainer/trainer.py  (+7 / -6)

@@ -377,6 +377,7 @@ def __init__(
         self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
+        self.num_sanity_val_batches = []
         self.num_test_batches = []
         self.train_dataloader = None
         self.test_dataloaders = None
@@ -463,9 +464,9 @@ def __init__(
         self.min_steps = min_steps
 
         if num_sanity_val_steps == -1:
-            self.num_sanity_val_steps = float("inf")
+            self.num_sanity_val_steps = float('inf')
         else:
-            self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
+            self.num_sanity_val_steps = num_sanity_val_steps
 
         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
 
@@ -1239,22 +1240,22 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()
 
     def _run_sanity_check(self, ref_model, model):
-
         using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
         if should_sanity_check:
             self.reset_val_dataloader(ref_model)
+            self.num_sanity_val_batches = [
+                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+            ]
 
             # hook and callback
             self.running_sanity_check = True
             self.on_sanity_check_start()
 
-            num_loaders = len(self.val_dataloaders)
-            max_batches = [self.num_sanity_val_steps] * num_loaders
-            eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
+            eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)
 
             # allow no returns from eval
             if eval_results is not None and len(eval_results) > 0:
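
Note: self.num_val_batches is filled in by reset_val_dataloader after limit_val_batches has been applied, so clipping against those counts is what makes the sanity check respect the limit, including fractional limits that the old min(num_sanity_val_steps, limit_val_batches) expression handled badly. A sketch of that interaction, assuming a single dataloader of 10 batches and the fractional scaling used by the dataloader reset (an approximation, not the exact Trainer code):

import math

len_dataloader = 10
limit_val_batches = 0.3          # use 30% of the dataloader for validation
num_sanity_val_steps = 4

# roughly what reset_val_dataloader computes for a fractional limit
num_val_batches = [int(math.floor(len_dataloader * limit_val_batches))]            # [3]

# the new per-dataloader clipping from _run_sanity_check
num_sanity_val_batches = [min(num_sanity_val_steps, b) for b in num_val_batches]   # [3]

# the old construction-time clipping would have yielded min(4, 0.3) == 0.3,
# i.e. a fractional "number of steps"; that is the bug this commit fixes
print(num_sanity_val_batches)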

tests/trainer/test_trainer.py  (+35 / -5)

@@ -907,28 +907,58 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     pytest.param(0.0),  # this should run no sanity checks
     pytest.param(1),
     pytest.param(1.0),
-    pytest.param(0.3),
+    pytest.param(0.5),
+    pytest.param(5),
 ])
 def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+    """ Test that the number of sanity check batches is clipped to limit_val_batches. """
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    num_sanity_val_steps = 4
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=num_sanity_val_steps,
+        limit_val_batches=limit_val_batches,
+        max_steps=1,
+    )
+    assert trainer.num_sanity_val_steps == num_sanity_val_steps
+    val_dataloaders = model.val_dataloader__multiple_mixed_length()
+
+    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+        assert mocked.call_count == sum(
+            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
+        )
+
+
+@pytest.mark.parametrize(['limit_val_batches'], [
+    pytest.param(0.0),  # this should run no sanity checks
+    pytest.param(1),
+    pytest.param(1.0),
+    pytest.param(0.3),
+])
+def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
     """
-    Test that num_sanity_val_steps=-1 runs through all validation data once.
-    Makes sure this setting is independent of limit_val_batches.
+    Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
+    limited by "limit_val_batches" Trainer argument.
     """
     model = EvalModelTemplate()
     model.validation_step = model.validation_step__multiple_dataloaders
     model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=-1,
-        limit_val_batches=limit_val_batches,  # should have no influence
+        limit_val_batches=limit_val_batches,
         max_steps=1,
     )
     assert trainer.num_sanity_val_steps == float('inf')
     val_dataloaders = model.val_dataloader__multiple()
 
     with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
         trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+        assert mocked.call_count == sum(trainer.num_val_batches)
 
 
 @pytest.mark.parametrize("trainer_kwargs,expected", [
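
Note on the test assertions: with mixed-length validation dataloaders, the expected evaluation_forward call count is the per-dataloader clipped sum, and for num_sanity_val_steps=-1 it is simply the total number of (limited) validation batches. A worked example with hypothetical batch counts (the real values come from val_dataloader__multiple_mixed_length and limit_val_batches):

# hypothetical: after limit_val_batches is applied, the two val dataloaders
# contribute this many batches during the sanity check
num_val_batches = [2, 5]
num_sanity_val_steps = 4

# test_num_sanity_val_steps: min(4, 2) + min(4, 5) == 6
expected_clipped = sum(min(num_sanity_val_steps, b) for b in num_val_batches)

# test_num_sanity_val_steps_neg_one: every limited batch is visited once
expected_full = sum(num_val_batches)   # 7

print(expected_clipped, expected_full)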
