Skip to content

Commit 6b667b1

Browse files
Fix/test pass overrides (#918)
* Fix test requiring both test_step and test_end * Add test Co-authored-by: William Falcon <[email protected]>
1 parent 2b5293d commit 6b667b1

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

pytorch_lightning/trainer/evaluation_loop.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,8 @@ def evaluate(self, model, dataloaders, max_batches, test=False):
301301

302302
def run_evaluation(self, test=False):
303303
# when testing make sure user defined a test step
304-
if test and not (self.is_overriden('test_step')):
305-
m = '''You called `.test()` without defining model's `.test_step()`.
304+
if test and not (self.is_overriden('test_step') or self.is_overriden('test_end')):
305+
m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`.
306306
Please define and try again'''
307307
raise MisconfigurationException(m)
308308

tests/test_trainer.py

+34
Original file line numberDiff line numberDiff line change
@@ -782,5 +782,39 @@ def test_trainer_min_steps_and_epochs(tmpdir):
782782
trainer.current_epoch > 0, "Model did not train for at least min_steps"
783783

784784

785+
def test_testpass_overrides(tmpdir):
786+
hparams = tutils.get_hparams()
787+
from pytorch_lightning.utilities.debugging import MisconfigurationException
788+
789+
class TestModelNoEnd(LightningTestModelBase):
790+
def test_step(self, *args, **kwargs):
791+
return {}
792+
793+
def test_dataloader(self):
794+
return self.train_dataloader()
795+
796+
class TestModelNoStep(LightningTestModelBase):
797+
def test_end(self, outputs):
798+
return {}
799+
800+
def test_dataloader(self):
801+
return self.train_dataloader()
802+
803+
# Misconfig when neither test_step or test_end is implemented
804+
with pytest.raises(MisconfigurationException):
805+
model = LightningTestModelBase(hparams)
806+
Trainer().test(model)
807+
808+
# No exceptions when one or both of test_step or test_end are implemented
809+
model = TestModelNoStep(hparams)
810+
Trainer().test(model)
811+
812+
model = TestModelNoEnd(hparams)
813+
Trainer().test(model)
814+
815+
model = LightningTestModel(hparams)
816+
Trainer().test(model)
817+
818+
785819
# if __name__ == '__main__':
786820
# pytest.main([__file__])

0 commit comments

Comments
 (0)