Skip to content

Commit 56d521a

Browse files
authored
Fix test configuration check and testing (#1804)
* Fix test configuration check and testing * Fix test configuration check and testing * Remove check_testing_configuration during test * Fix docstring * fix function name * remove conflicts
1 parent 4cdebf9 commit 56d521a

File tree

2 files changed

+38
-62
lines changed

2 files changed

+38
-62
lines changed

pytorch_lightning/trainer/evaluation_loop.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -334,20 +334,13 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
334334
return eval_results
335335

336336
def run_evaluation(self, test_mode: bool = False):
337-
# when testing make sure user defined a test step
338-
if test_mode and not self.is_overridden('test_step'):
339-
raise MisconfigurationException(
340-
"You called `.test()` without defining model's `.test_step()`."
341-
" Please define and try again")
342-
343337
# hook
344338
model = self.get_model()
345339
model.on_pre_performance_check()
346340

347341
# select dataloaders
348342
if test_mode:
349-
if self.test_dataloaders is None:
350-
self.reset_test_dataloader(model)
343+
self.reset_test_dataloader(model)
351344

352345
dataloaders = self.test_dataloaders
353346
max_batches = self.num_test_batches

pytorch_lightning/trainer/trainer.py

+37-54
Original file line numberDiff line numberDiff line change
@@ -1055,9 +1055,6 @@ def test(
10551055
else:
10561056
self.__attach_dataloaders(self.model, test_dataloaders=test_dataloaders)
10571057

1058-
# give proper warnings if user only passed in loader without hooks
1059-
self.check_testing_model_configuration(model if model else self.model)
1060-
10611058
if model is not None:
10621059
self.model = model
10631060
self.fit(model)
@@ -1076,44 +1073,45 @@ def test(
10761073

10771074
def check_model_configuration(self, model: LightningModule):
10781075
r"""
1079-
Checks that the model is configured correctly before training is started.
1076+
Checks that the model is configured correctly before training or testing is started.
10801077
10811078
Args:
1082-
model: The model to test.
1079+
model: The model to check the configuration.
10831080
10841081
"""
10851082
# Check training_step, train_dataloader, configure_optimizer methods
1086-
if not self.is_overridden('training_step', model):
1087-
raise MisconfigurationException(
1088-
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
1089-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')
1090-
1091-
if not self.is_overridden('train_dataloader', model):
1092-
raise MisconfigurationException(
1093-
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
1094-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')
1095-
1096-
if not self.is_overridden('configure_optimizers', model):
1097-
raise MisconfigurationException(
1098-
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
1099-
' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.')
1100-
1101-
# Check val_dataloader, validation_step and validation_epoch_end
1102-
if self.is_overridden('val_dataloader', model):
1103-
if not self.is_overridden('validation_step', model):
1104-
raise MisconfigurationException('You have passed in a `val_dataloader()`'
1105-
' but have not defined `validation_step()`.')
1083+
if not self.testing:
1084+
if not self.is_overridden('training_step', model):
1085+
raise MisconfigurationException(
1086+
'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
1087+
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')
1088+
1089+
if not self.is_overridden('train_dataloader', model):
1090+
raise MisconfigurationException(
1091+
'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
1092+
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')
1093+
1094+
if not self.is_overridden('configure_optimizers', model):
1095+
raise MisconfigurationException(
1096+
'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
1097+
' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.')
1098+
1099+
# Check val_dataloader, validation_step and validation_epoch_end
1100+
if self.is_overridden('val_dataloader', model):
1101+
if not self.is_overridden('validation_step', model):
1102+
raise MisconfigurationException('You have passed in a `val_dataloader()`'
1103+
' but have not defined `validation_step()`.')
1104+
else:
1105+
if not self.is_overridden('validation_epoch_end', model):
1106+
rank_zero_warn(
1107+
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
1108+
' you may also want to define `validation_epoch_end()` for accumulating stats.',
1109+
RuntimeWarning
1110+
)
11061111
else:
1107-
if not self.is_overridden('validation_epoch_end', model):
1108-
rank_zero_warn(
1109-
'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
1110-
' you may also want to define `validation_epoch_end()` for accumulating stats.',
1111-
RuntimeWarning
1112-
)
1113-
else:
1114-
if self.is_overridden('validation_step', model):
1115-
raise MisconfigurationException('You have defined `validation_step()`,'
1116-
' but have not passed in a val_dataloader().')
1112+
if self.is_overridden('validation_step', model):
1113+
raise MisconfigurationException('You have defined `validation_step()`,'
1114+
' but have not passed in a `val_dataloader()`.')
11171115

11181116
# Check test_dataloader, test_step and test_epoch_end
11191117
if self.is_overridden('test_dataloader', model):
@@ -1126,25 +1124,10 @@ def check_model_configuration(self, model: LightningModule):
11261124
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
11271125
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
11281126
)
1129-
1130-
def check_testing_model_configuration(self, model: LightningModule):
1131-
1132-
has_test_step = self.is_overridden('test_step', model)
1133-
has_test_epoch_end = self.is_overridden('test_epoch_end', model)
1134-
gave_test_loader = self.is_overridden('test_dataloader', model)
1135-
1136-
if gave_test_loader and not has_test_step:
1137-
raise MisconfigurationException('You passed in a `test_dataloader` but did not implement `test_step()`')
1138-
1139-
if has_test_step and not gave_test_loader:
1140-
raise MisconfigurationException('You defined `test_step()` but did not implement'
1141-
' `test_dataloader` nor passed in `.fit(test_dataloaders`.')
1142-
1143-
if has_test_step and gave_test_loader and not has_test_epoch_end:
1144-
rank_zero_warn(
1145-
'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
1146-
' define `test_epoch_end()` for accumulating stats.', RuntimeWarning
1147-
)
1127+
else:
1128+
if self.testing and self.is_overridden('test_step', model):
1129+
raise MisconfigurationException('You have defined `test_step()` but did not'
1130+
' implement `test_dataloader` nor passed in `.test(test_dataloader)`.')
11481131

11491132

11501133
class _PatchDataLoader(object):

0 commit comments

Comments
 (0)