@@ -36,6 +36,15 @@ def __init__(self):
36
36
self .shown_warnings = None
37
37
self .val_check_interval = None
38
38
39
def _percent_range_check(self, name):
    """Validate that the attribute called ``name`` holds a fraction in [0.0, 1.0].

    Args:
        name: name of the attribute on ``self`` to validate
            (for example ``'train_percent_check'``).

    Raises:
        ValueError: if the attribute value is not inside [0.0, 1.0]
            (NaN fails the comparison and is rejected too), or if it cannot
            be ordered against a float at all (for example ``None``).
    """
    value = getattr(self, name)

    try:
        in_range = 0. <= value <= 1.
    except TypeError:
        # Values that cannot be compared with a float (None, str, ...) are
        # reported through the same ValueError as any other bad setting.
        in_range = False
    if in_range:
        return

    # Format the offending value only on the failure path; fall back to repr()
    # for values that do not support the float format spec, so the intended
    # error is not masked by a formatting error.
    try:
        shown = f"{value:.3f}"
    except (TypeError, ValueError):
        shown = repr(value)

    msg = f"`{name}` must lie in the range [0.0, 1.0], but got {shown}."
    if name == "val_check_interval":
        msg += " If you want to disable validation set `val_percent_check` to 0.0 instead."

    raise ValueError(msg)
47
+
39
48
def init_train_dataloader (self , model ):
40
49
"""
41
50
Dataloaders are provided by the model
@@ -48,6 +57,8 @@ def init_train_dataloader(self, model):
48
57
if EXIST_ITER_DATASET and isinstance (self .get_train_dataloader ().dataset , IterableDataset ):
49
58
self .num_training_batches = float ('inf' )
50
59
else :
60
+ self ._percent_range_check ('train_percent_check' )
61
+
51
62
self .num_training_batches = len (self .get_train_dataloader ())
52
63
self .num_training_batches = int (self .num_training_batches * self .train_percent_check )
53
64
@@ -56,7 +67,14 @@ def init_train_dataloader(self, model):
56
67
# otherwise, it checks in [0, 1.0] % range of a training epoch
57
68
if isinstance (self .val_check_interval , int ):
58
69
self .val_check_batch = self .val_check_interval
70
+ if self .val_check_batch > self .num_training_batches :
71
+ raise ValueError (
72
+ f"`val_check_interval` ({ self .val_check_interval } ) must be less than or equal "
73
+ f"to the number of the training batches ({ self .num_training_batches } ). "
74
+ f"If you want to disable validation set `val_percent_check` to 0.0 instead." )
59
75
else :
76
+ self ._percent_range_check ('val_check_interval' )
77
+
60
78
self .val_check_batch = int (self .num_training_batches * self .val_check_interval )
61
79
self .val_check_batch = max (1 , self .val_check_batch )
62
80
@@ -89,13 +107,15 @@ def init_val_dataloader(self, model):
89
107
:return:
90
108
"""
91
109
self .get_val_dataloaders = model .val_dataloader
110
+ self .num_val_batches = 0
92
111
93
112
# determine number of validation batches
94
113
# val datasets could be none, 1 or 2+
95
114
if self .get_val_dataloaders () is not None :
115
+ self ._percent_range_check ('val_percent_check' )
116
+
96
117
self .num_val_batches = sum (len (dataloader ) for dataloader in self .get_val_dataloaders ())
97
118
self .num_val_batches = int (self .num_val_batches * self .val_percent_check )
98
- self .num_val_batches = max (1 , self .num_val_batches )
99
119
100
120
on_ddp = self .use_ddp or self .use_ddp2
101
121
if on_ddp and self .get_val_dataloaders () is not None :
@@ -134,10 +154,11 @@ def init_test_dataloader(self, model):
134
154
135
155
# determine number of test batches
136
156
if self .get_test_dataloaders () is not None :
157
+ self ._percent_range_check ('test_percent_check' )
158
+
137
159
len_sum = sum (len (dataloader ) for dataloader in self .get_test_dataloaders ())
138
160
self .num_test_batches = len_sum
139
161
self .num_test_batches = int (self .num_test_batches * self .test_percent_check )
140
- self .num_test_batches = max (1 , self .num_test_batches )
141
162
142
163
on_ddp = self .use_ddp or self .use_ddp2
143
164
if on_ddp and self .get_test_dataloaders () is not None :
@@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check,
208
229
self .val_percent_check = val_percent_check
209
230
self .test_percent_check = test_percent_check
210
231
if overfit_pct > 0 :
232
+ if overfit_pct > 1 :
233
+ raise ValueError (f"`overfit_pct` must be not greater than 1.0, but got "
234
+ f"{ overfit_pct :.3f} ." )
235
+
211
236
self .train_percent_check = overfit_pct
212
237
self .val_percent_check = overfit_pct
213
238
self .test_percent_check = overfit_pct
0 commit comments