Commit 7824b5c

kuynzereb authored and williamFalcon committed

Fix percent_checks (#649)

* fix percent_checks
* Added _percent_range_check
* remove max

1 parent 9ac91ad · commit 7824b5c

File tree

1 file changed: +27 -2 lines changed

pytorch_lightning/trainer/data_loading.py (+27 -2)
@@ -36,6 +36,15 @@ def __init__(self):
         self.shown_warnings = None
         self.val_check_interval = None
 
+    def _percent_range_check(self, name):
+        value = getattr(self, name)
+        msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
+        if name == "val_check_interval":
+            msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."
+
+        if not 0. <= value <= 1.:
+            raise ValueError(msg)
+
     def init_train_dataloader(self, model):
         """
         Dataloaders are provided by the model
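
The helper looks the attribute up by name, so one message template covers every percentage setting, with an extra hint for `val_check_interval`. A minimal sketch of the behaviour in isolation (the `_Stub` class is illustrative only, not part of the commit):

    class _Stub:
        # copy of the new helper, attached to a bare object for demonstration
        def _percent_range_check(self, name):
            value = getattr(self, name)
            msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}."
            if name == "val_check_interval":
                msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."

            if not 0. <= value <= 1.:
                raise ValueError(msg)

    stub = _Stub()
    stub.train_percent_check = 1.5
    stub._percent_range_check('train_percent_check')
    # ValueError: `train_percent_check` must lie in the range [0.0, 1.0], but got 1.500.
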
@@ -48,6 +57,8 @@ def init_train_dataloader(self, model):
         if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
             self.num_training_batches = float('inf')
         else:
+            self._percent_range_check('train_percent_check')
+
             self.num_training_batches = len(self.get_train_dataloader())
             self.num_training_batches = int(self.num_training_batches * self.train_percent_check)
 
@@ -56,7 +67,14 @@ def init_train_dataloader(self, model):
         # otherwise, it checks in [0, 1.0] % range of a training epoch
         if isinstance(self.val_check_interval, int):
             self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of the training batches ({self.num_training_batches}). "
+                    f"If you want to disable validation set `val_percent_check` to 0.0 instead.")
         else:
+            self._percent_range_check('val_check_interval')
+
             self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
             self.val_check_batch = max(1, self.val_check_batch)
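
`val_check_interval` keeps its dual meaning: an int is "validate every N training batches" and must now fit inside the epoch, while a float is a fraction of the epoch that goes through the new range check before being scaled (and truncated by `int()`). A sketch with made-up numbers:

    num_training_batches = 100          # illustrative value

    val_check_interval = 250            # int, and 250 > 100 -> the new ValueError is raised

    val_check_interval = 0.25           # float -> range-checked, then:
    val_check_batch = max(1, int(num_training_batches * val_check_interval))   # 25
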

@@ -89,13 +107,15 @@ def init_val_dataloader(self, model):
         :return:
         """
         self.get_val_dataloaders = model.val_dataloader
+        self.num_val_batches = 0
 
         # determine number of validation batches
         # val datasets could be none, 1 or 2+
         if self.get_val_dataloaders() is not None:
+            self._percent_range_check('val_percent_check')
+
             self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders())
             self.num_val_batches = int(self.num_val_batches * self.val_percent_check)
-            self.num_val_batches = max(1, self.num_val_batches)
 
         on_ddp = self.use_ddp or self.use_ddp2
         if on_ddp and self.get_val_dataloaders() is not None:
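
Two related changes here: `num_val_batches` now starts at 0 before any dataloader is inspected, and the `max(1, ...)` clamp is removed, so `val_percent_check=0.0` genuinely disables validation instead of silently running one batch; the same clamp is dropped for the test dataloaders in the next hunk. Illustrative arithmetic:

    # illustrative values only
    num_val_batches = int(200 * 0.0)    # 0 -> validation is skipped (previously clamped to 1)
    num_val_batches = int(200 * 0.1)    # 20 validation batches
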
@@ -134,10 +154,11 @@ def init_test_dataloader(self, model):
 
         # determine number of test batches
         if self.get_test_dataloaders() is not None:
+            self._percent_range_check('test_percent_check')
+
             len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
             self.num_test_batches = len_sum
             self.num_test_batches = int(self.num_test_batches * self.test_percent_check)
-            self.num_test_batches = max(1, self.num_test_batches)
 
         on_ddp = self.use_ddp or self.use_ddp2
         if on_ddp and self.get_test_dataloaders() is not None:
@@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
         self.val_percent_check = val_percent_check
         self.test_percent_check = test_percent_check
         if overfit_pct > 0:
+            if overfit_pct > 1:
+                raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got "
+                                 f"{overfit_pct:.3f}.")
+
             self.train_percent_check = overfit_pct
             self.val_percent_check = overfit_pct
             self.test_percent_check = overfit_pct
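
`overfit_pct` remains a shorthand that overrides all three percentages at once; the new guard only rejects values above 1.0. A hedged usage sketch, assuming these options are still exposed as Trainer constructor arguments in this version of the library:

    from pytorch_lightning import Trainer

    trainer = Trainer(overfit_pct=0.01)          # train/val/test_percent_check all become 0.01
    trainer = Trainer(val_percent_check=0.0)     # disables validation (no clamp to 1 batch any more)
    trainer = Trainer(train_percent_check=1.5)   # raises the ValueError above once the train dataloader is set up
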

0 commit comments