Skip to content

Commit bcae6b8

Browse files
Adrian Wälchlitullie
Adrian Wälchli
authored andcommitted
Fix for incorrect run on the validation set with overwritten validation_epoch_end and test_end (Lightning-AI#1353)
* reorder if clauses * fix wrong method overload in test * fix formatting * update change_log * fix line too long
1 parent d841f03 commit bcae6b8

File tree

3 files changed

+32
-22
lines changed

3 files changed

+32
-22
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6161
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
6262
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
6363
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
64+
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).
6465

6566
## [0.7.1] - 2020-03-07
6667

pytorch_lightning/trainer/evaluation_loop.py

+19-14
Original file line numberDiff line numberDiff line change
@@ -295,20 +295,25 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
295295
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
296296
model = model.module
297297

298-
# TODO: remove in v1.0.0
299-
if test_mode and self.is_overriden('test_end', model=model):
300-
eval_results = model.test_end(outputs)
301-
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
302-
' Use `test_epoch_end` instead.', DeprecationWarning)
303-
elif self.is_overriden('validation_end', model=model):
304-
eval_results = model.validation_end(outputs)
305-
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
306-
' Use `validation_epoch_end` instead.', DeprecationWarning)
307-
308-
if test_mode and self.is_overriden('test_epoch_end', model=model):
309-
eval_results = model.test_epoch_end(outputs)
310-
elif self.is_overriden('validation_epoch_end', model=model):
311-
eval_results = model.validation_epoch_end(outputs)
298+
if test_mode:
299+
if self.is_overriden('test_end', model=model):
300+
# TODO: remove in v1.0.0
301+
eval_results = model.test_end(outputs)
302+
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
303+
' Use `test_epoch_end` instead.', DeprecationWarning)
304+
305+
elif self.is_overriden('test_epoch_end', model=model):
306+
eval_results = model.test_epoch_end(outputs)
307+
308+
else:
309+
if self.is_overriden('validation_end', model=model):
310+
# TODO: remove in v1.0.0
311+
eval_results = model.validation_end(outputs)
312+
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
313+
' Use `validation_epoch_end` instead.', DeprecationWarning)
314+
315+
elif self.is_overriden('validation_epoch_end', model=model):
316+
eval_results = model.validation_epoch_end(outputs)
312317

313318
# enable train mode again
314319
model.train()

tests/trainer/test_trainer.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -528,15 +528,15 @@ def test_disabled_validation():
528528
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
529529

530530
validation_step_invoked = False
531-
validation_end_invoked = False
531+
validation_epoch_end_invoked = False
532532

533533
def validation_step(self, *args, **kwargs):
534534
self.validation_step_invoked = True
535535
return super().validation_step(*args, **kwargs)
536536

537-
def validation_end(self, *args, **kwargs):
538-
self.validation_end_invoked = True
539-
return super().validation_end(*args, **kwargs)
537+
def validation_epoch_end(self, *args, **kwargs):
538+
self.validation_epoch_end_invoked = True
539+
return super().validation_epoch_end(*args, **kwargs)
540540

541541
hparams = tutils.get_default_hparams()
542542
model = CurrentModel(hparams)
@@ -555,8 +555,10 @@ def validation_end(self, *args, **kwargs):
555555
# check that val_percent_check=0 turns off validation
556556
assert result == 1, 'training failed to complete'
557557
assert trainer.current_epoch == 1
558-
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
559-
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
558+
assert not model.validation_step_invoked, \
559+
'`validation_step` should not run when `val_percent_check=0`'
560+
assert not model.validation_epoch_end_invoked, \
561+
'`validation_epoch_end` should not run when `val_percent_check=0`'
560562

561563
# check that val_percent_check has no influence when fast_dev_run is turned on
562564
model = CurrentModel(hparams)
@@ -566,8 +568,10 @@ def validation_end(self, *args, **kwargs):
566568

567569
assert result == 1, 'training failed to complete'
568570
assert trainer.current_epoch == 0
569-
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
570-
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
571+
assert model.validation_step_invoked, \
572+
'did not run `validation_step` with `fast_dev_run=True`'
573+
assert model.validation_epoch_end_invoked, \
574+
'did not run `validation_epoch_end` with `fast_dev_run=True`'
571575

572576

573577
def test_nan_loss_detection(tmpdir):

0 commit comments

Comments
 (0)