@@ -336,7 +336,8 @@ def __init__(
 
             amp_level: The optimization level to use (O1, O2, etc...).
 
-            num_sanity_val_steps: Sanity check runs n batches of val before starting the training routine.
+            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
+                Set it to `-1` to run all batches in all validation dataloaders. Default: 2
 
             truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of
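For reviewers, a minimal usage sketch of the documented behavior (plain `Trainer` construction; the model and dataloader setup are omitted and assumed):

    from pytorch_lightning import Trainer

    # Default: run 2 validation batches before training starts.
    trainer = Trainer(num_sanity_val_steps=2)

    # -1 now runs ALL batches in ALL validation dataloaders before training.
    trainer = Trainer(num_sanity_val_steps=-1)

    # 0 still disables the sanity check entirely.
    trainer = Trainer(num_sanity_val_steps=0)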
@@ -408,7 +409,6 @@ def __init__(
         # training state
         self.model = None
         self.testing = False
-        self.disable_validation = False
         self.prepare_data_per_node = prepare_data_per_node
         self.lr_schedulers = []
         self.optimizers = None
@@ -488,7 +488,7 @@ def __init__(
         self.max_steps = max_steps
         self.min_steps = min_steps
 
-        self.num_sanity_val_steps = num_sanity_val_steps
+        self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
         # Backward compatibility, TODO: remove in v0.9.0
         if print_nan_grads:
             rank_zero_warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
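The `-1` sentinel is normalized once at construction time so downstream loops can keep a plain numeric comparison. A standalone sketch of the same normalization, under an illustrative helper name:

    # Hypothetical helper mirroring the line added above.
    def normalize_sanity_steps(num_sanity_val_steps: int) -> float:
        # -1 means "check every val batch"; float("inf") guarantees that any
        # `batch_idx >= num_sanity_val_steps` cutoff never fires.
        return float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps

    assert normalize_sanity_steps(-1) == float("inf")
    assert normalize_sanity_steps(2) == 2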
@@ -883,6 +883,17 @@ def progress_bar_dict(self) -> dict:
         ref_model = self.model if not self.data_parallel else self.model.module
         return dict(**ref_model.get_progress_bar_dict(), **self.progress_bar_metrics)
 
+    @property
+    def disable_validation(self) -> bool:
+        """ Check if validation is disabled during training. """
+        return not self.enable_validation
+
+    @property
+    def enable_validation(self) -> bool:
+        """ Check if we should run validation during training. """
+        val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
+        return val_loop_enabled or self.fast_dev_run
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
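This property pair replaces the mutable `disable_validation` flag removed above, so the answer is always derived from current state rather than set once. A toy stub exercising the same logic in isolation (the stub class and its attributes are assumptions for illustration, not Trainer API):

    # Stand-in with just enough state to exercise the two properties.
    class _TrainerStub:
        def __init__(self, has_val_step, limit_val_batches, fast_dev_run):
            self.has_val_step = has_val_step
            self.limit_val_batches = limit_val_batches
            self.fast_dev_run = fast_dev_run

        def is_overridden(self, name):
            return name == 'validation_step' and self.has_val_step

        @property
        def enable_validation(self):
            val_loop_enabled = self.is_overridden('validation_step') and self.limit_val_batches > 0
            return val_loop_enabled or self.fast_dev_run

        @property
        def disable_validation(self):
            return not self.enable_validation

    # validation_step defined but limit_val_batches=0 -> validation disabled...
    assert _TrainerStub(True, 0, False).disable_validation
    # ...unless fast_dev_run forces a quick validation pass.
    assert _TrainerStub(True, 0, True).enable_validation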
@@ -1186,10 +1197,6 @@ def run_pretrain_routine(self, model: LightningModule):
 
             return eval_loop_results
 
-        # check if we should run validation during training
-        self.disable_validation = not (self.is_overridden('validation_step') and self.limit_val_batches > 0) \
-            and not self.fast_dev_run
-
         # run a few val batches before training starts
         self._run_sanity_check(ref_model, model)
 
@@ -1204,9 +1211,12 @@ def run_pretrain_routine(self, model: LightningModule):
         self.train()
 
     def _run_sanity_check(self, ref_model, model):
+        should_sanity_check = self.is_overridden('validation_step') and self.num_sanity_val_steps > 0 \
+            and self.limit_val_batches > 0
+
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
-        if not self.disable_validation and self.num_sanity_val_steps > 0:
+        if should_sanity_check:
             self.reset_val_dataloader(ref_model)
 
             # hook and callback
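Worth noting: unlike `enable_validation`, the new `should_sanity_check` predicate does not consult `fast_dev_run`, so per this diff a `fast_dev_run` run with `num_sanity_val_steps=0` skips the sanity check while still validating during training. A pure-function sketch of the gate (the function name is illustrative):

    def should_run_sanity_check(has_val_step, num_sanity_val_steps, limit_val_batches):
        # Runs only if validation_step exists, sanity batches were requested,
        # and validation isn't limited to zero batches.
        return has_val_step and num_sanity_val_steps > 0 and limit_val_batches > 0

    assert should_run_sanity_check(True, float("inf"), 1.0)  # num_sanity_val_steps=-1
    assert not should_run_sanity_check(True, 0, 1.0)         # explicitly disabled
    assert not should_run_sanity_check(False, 2, 1.0)        # no validation_step defined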