@@ -1055,9 +1055,6 @@ def test(
1055
1055
else :
1056
1056
self .__attach_dataloaders (self .model , test_dataloaders = test_dataloaders )
1057
1057
1058
- # give proper warnings if user only passed in loader without hooks
1059
- self .check_testing_model_configuration (model if model else self .model )
1060
-
1061
1058
if model is not None :
1062
1059
self .model = model
1063
1060
self .fit (model )
@@ -1076,44 +1073,45 @@ def test(
1076
1073
1077
1074
def check_model_configuration (self , model : LightningModule ):
1078
1075
r"""
1079
- Checks that the model is configured correctly before training is started.
1076
+ Checks that the model is configured correctly before training or testing is started.
1080
1077
1081
1078
Args:
1082
- model: The model to test .
1079
+ model: The model to check the configuration .
1083
1080
1084
1081
"""
1085
1082
# Check training_step, train_dataloader, configure_optimizer methods
1086
- if not self .is_overridden ('training_step' , model ):
1087
- raise MisconfigurationException (
1088
- 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
1089
- ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' )
1090
-
1091
- if not self .is_overridden ('train_dataloader' , model ):
1092
- raise MisconfigurationException (
1093
- 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
1094
- ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' )
1095
-
1096
- if not self .is_overridden ('configure_optimizers' , model ):
1097
- raise MisconfigurationException (
1098
- 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
1099
- ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' )
1100
-
1101
- # Check val_dataloader, validation_step and validation_epoch_end
1102
- if self .is_overridden ('val_dataloader' , model ):
1103
- if not self .is_overridden ('validation_step' , model ):
1104
- raise MisconfigurationException ('You have passed in a `val_dataloader()`'
1105
- ' but have not defined `validation_step()`.' )
1083
+ if not self .testing :
1084
+ if not self .is_overridden ('training_step' , model ):
1085
+ raise MisconfigurationException (
1086
+ 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
1087
+ ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.' )
1088
+
1089
+ if not self .is_overridden ('train_dataloader' , model ):
1090
+ raise MisconfigurationException (
1091
+ 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
1092
+ ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.' )
1093
+
1094
+ if not self .is_overridden ('configure_optimizers' , model ):
1095
+ raise MisconfigurationException (
1096
+ 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
1097
+ ' `training_step()`, `training_dataloader()` and `configure_optimizers()` to be defined.' )
1098
+
1099
+ # Check val_dataloader, validation_step and validation_epoch_end
1100
+ if self .is_overridden ('val_dataloader' , model ):
1101
+ if not self .is_overridden ('validation_step' , model ):
1102
+ raise MisconfigurationException ('You have passed in a `val_dataloader()`'
1103
+ ' but have not defined `validation_step()`.' )
1104
+ else :
1105
+ if not self .is_overridden ('validation_epoch_end' , model ):
1106
+ rank_zero_warn (
1107
+ 'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
1108
+ ' you may also want to define `validation_epoch_end()` for accumulating stats.' ,
1109
+ RuntimeWarning
1110
+ )
1106
1111
else :
1107
- if not self .is_overridden ('validation_epoch_end' , model ):
1108
- rank_zero_warn (
1109
- 'You have defined a `val_dataloader()` and have defined a `validation_step()`,'
1110
- ' you may also want to define `validation_epoch_end()` for accumulating stats.' ,
1111
- RuntimeWarning
1112
- )
1113
- else :
1114
- if self .is_overridden ('validation_step' , model ):
1115
- raise MisconfigurationException ('You have defined `validation_step()`,'
1116
- ' but have not passed in a val_dataloader().' )
1112
+ if self .is_overridden ('validation_step' , model ):
1113
+ raise MisconfigurationException ('You have defined `validation_step()`,'
1114
+ ' but have not passed in a `val_dataloader()`.' )
1117
1115
1118
1116
# Check test_dataloader, test_step and test_epoch_end
1119
1117
if self .is_overridden ('test_dataloader' , model ):
@@ -1126,25 +1124,10 @@ def check_model_configuration(self, model: LightningModule):
1126
1124
'You have defined a `test_dataloader()` and have defined a `test_step()`, you may also want to'
1127
1125
' define `test_epoch_end()` for accumulating stats.' , RuntimeWarning
1128
1126
)
1129
-
1130
- def check_testing_model_configuration (self , model : LightningModule ):
1131
-
1132
- has_test_step = self .is_overridden ('test_step' , model )
1133
- has_test_epoch_end = self .is_overridden ('test_epoch_end' , model )
1134
- gave_test_loader = self .is_overridden ('test_dataloader' , model )
1135
-
1136
- if gave_test_loader and not has_test_step :
1137
- raise MisconfigurationException ('You passed in a `test_dataloader` but did not implement `test_step()`' )
1138
-
1139
- if has_test_step and not gave_test_loader :
1140
- raise MisconfigurationException ('You defined `test_step()` but did not implement'
1141
- ' `test_dataloader` nor passed in `.fit(test_dataloaders`.' )
1142
-
1143
- if has_test_step and gave_test_loader and not has_test_epoch_end :
1144
- rank_zero_warn (
1145
- 'You passed in a `test_dataloader` and have defined a `test_step()`, you may also want to'
1146
- ' define `test_epoch_end()` for accumulating stats.' , RuntimeWarning
1147
- )
1127
+ else :
1128
+ if self .testing and self .is_overridden ('test_step' , model ):
1129
+ raise MisconfigurationException ('You have defined `test_step()` but did not'
1130
+ ' implement `test_dataloader` nor passed in `.test(test_dataloader)`.' )
1148
1131
1149
1132
1150
1133
class _PatchDataLoader (object ):
0 commit comments