diff --git a/pytorch_lightning/trainer/train_loop_mixin.py b/pytorch_lightning/trainer/train_loop_mixin.py
index 0935a3cf62789..5ee7ae06890ba 100644
--- a/pytorch_lightning/trainer/train_loop_mixin.py
+++ b/pytorch_lightning/trainer/train_loop_mixin.py
@@ -150,7 +150,8 @@ def training_step(self, batch, batch_nb):
     """
 
 import numpy as np
-import tqdm
+
+from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
     from apex import amp
@@ -213,7 +214,15 @@ def train(self):
             # update LR schedulers
             if self.lr_schedulers is not None:
                 for lr_scheduler in self.lr_schedulers:
-                    lr_scheduler.step(self.current_epoch)
+                    lr_scheduler.step(epoch=self.current_epoch)
+            if self.reduce_lr_on_plateau_scheduler is not None:
+                val_loss = self.callback_metrics.get('val_loss')
+                if val_loss is None:
+                    avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                    m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
+                        f'which is not available. Available metrics are: {avail_metrics}'
+                    raise MisconfigurationException(m)
+                self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)
 
             # early stopping
             met_min_epochs = epoch_nb > self.min_nb_epochs
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index fc34085972e03..b783e4406aac0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -186,6 +186,8 @@ def __init__(self,
         self.early_stop_callback = None
         self.configure_early_stopping(early_stop_callback, logger)
 
+        self.reduce_lr_on_plateau_scheduler = None
+
         # configure checkpoint callback
         self.checkpoint_callback = checkpoint_callback
         self.weights_save_path = weights_save_path
@@ -376,12 +378,20 @@ def init_optimizers(self, optimizers):
         # two lists
         elif len(optimizers) == 2 and isinstance(optimizers[0], list):
             optimizers, lr_schedulers = optimizers
+            lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
             return optimizers, lr_schedulers
 
         # single list or tuple
         elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
             return optimizers, []
 
+    def configure_schedulers(self, schedulers):
+        for i, scheduler in enumerate(schedulers):
+            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                reduce_lr_on_plateau_scheduler = schedulers.pop(i)
+                return schedulers, reduce_lr_on_plateau_scheduler
+        return schedulers, None
+
     def run_pretrain_routine(self, model):
         """
         Sanity check a few things before starting actual training
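
For context, a minimal sketch (not part of the diff) of how a user-side LightningModule would exercise this change: a ReduceLROnPlateau instance returned from configure_optimizers is split out by the new configure_schedulers() and stepped once per epoch with the 'val_loss' found in callback_metrics. The class name MyModel, its hyperparameters, and the assumption that validation_end's returned dict populates callback_metrics are illustrative, based on the LightningModule API of this era.

import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # hypothetical module; only the scheduler-related hooks are shown

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Returning a ReduceLROnPlateau in the scheduler list is what the new
        # configure_schedulers() detects and stores on the trainer as
        # reduce_lr_on_plateau_scheduler instead of the plain scheduler list.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3)
        return [optimizer], [scheduler]

    def validation_end(self, outputs):
        # A 'val_loss' entry must reach callback_metrics; otherwise the
        # MisconfigurationException added above is raised at the end of the epoch.
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}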