diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 84413697948d5..33bd99fcabc94 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -36,6 +36,15 @@ def __init__(self): self.shown_warnings = None self.val_check_interval = None + def _percent_range_check(self, name): + value = getattr(self, name) + msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." + if name == "val_check_interval": + msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." + + if not 0. <= value <= 1.: + raise ValueError(msg) + def init_train_dataloader(self, model): """ Dataloaders are provided by the model @@ -48,6 +57,8 @@ def init_train_dataloader(self, model): if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset): self.num_training_batches = float('inf') else: + self._percent_range_check('train_percent_check') + self.num_training_batches = len(self.get_train_dataloader()) self.num_training_batches = int(self.num_training_batches * self.train_percent_check) @@ -56,7 +67,14 @@ def init_train_dataloader(self, model): # otherwise, it checks in [0, 1.0] % range of a training epoch if isinstance(self.val_check_interval, int): self.val_check_batch = self.val_check_interval + if self.val_check_batch > self.num_training_batches: + raise ValueError( + f"`val_check_interval` ({self.val_check_interval}) must be less than or equal " + f"to the number of the training batches ({self.num_training_batches}). " + f"If you want to disable validation set `val_percent_check` to 0.0 instead.") else: + self._percent_range_check('val_check_interval') + self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) @@ -89,13 +107,15 @@ def init_val_dataloader(self, model): :return: """ self.get_val_dataloaders = model.val_dataloader + self.num_val_batches = 0 # determine number of validation batches # val datasets could be none, 1 or 2+ if self.get_val_dataloaders() is not None: + self._percent_range_check('val_percent_check') + self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) self.num_val_batches = int(self.num_val_batches * self.val_percent_check) - self.num_val_batches = max(1, self.num_val_batches) on_ddp = self.use_ddp or self.use_ddp2 if on_ddp and self.get_val_dataloaders() is not None: @@ -134,10 +154,11 @@ def init_test_dataloader(self, model): # determine number of test batches if self.get_test_dataloaders() is not None: + self._percent_range_check('test_percent_check') + len_sum = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) self.num_test_batches = len_sum self.num_test_batches = int(self.num_test_batches * self.test_percent_check) - self.num_test_batches = max(1, self.num_test_batches) on_ddp = self.use_ddp or self.use_ddp2 if on_ddp and self.get_test_dataloaders() is not None: @@ -208,6 +229,10 @@ def determine_data_use_amount(self, train_percent_check, val_percent_check, self.val_percent_check = val_percent_check self.test_percent_check = test_percent_check if overfit_pct > 0: + if overfit_pct > 1: + raise ValueError(f"`overfit_pct` must be not greater than 1.0, but got " + f"{overfit_pct:.3f}.") + self.train_percent_check = overfit_pct self.val_percent_check = overfit_pct self.test_percent_check = overfit_pct