Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.optim.lr_scheduler.ReduceLROnPlateau #320

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions pytorch_lightning/trainer/train_loop_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def training_step(self, batch, batch_nb):
"""

import numpy as np
import tqdm

from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
from apex import amp
Expand Down Expand Up @@ -213,7 +214,15 @@ def train(self):
# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step(self.current_epoch)
lr_scheduler.step(epoch=self.current_epoch)
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems very specific. does it only need to work with val_loss?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it could be any validation metric in theory, but how do we let the user pass it in? Perhaps via a dedicated dict entry in the validation_end output similar to "log" and "progress_bar"? It's one more thing the user needs to remember, but maybe its fine since this lr_scheduler is optional and it should be mentioned in the docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, good option. let’s do it in a separate PR?

if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

# early stopping
met_min_epochs = epoch_nb > self.min_nb_epochs
Expand Down
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ def __init__(self,
self.early_stop_callback = None
self.configure_early_stopping(early_stop_callback, logger)

self.reduce_lr_on_plateau_scheduler = None

# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
self.weights_save_path = weights_save_path
Expand Down Expand Up @@ -376,12 +378,20 @@ def init_optimizers(self, optimizers):
# two lists
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers

# single list or tuple
elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
return optimizers, []

def configure_schedulers(self, schedulers):
for i, scheduler in enumerate(schedulers):
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
reduce_lr_on_plateau_scheduler = schedulers.pop(i)
return schedulers, reduce_lr_on_plateau_scheduler
return schedulers, None

def run_pretrain_routine(self, model):
"""
Sanity check a few things before starting actual training
Expand Down