Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support torch.optim.lr_scheduler.ReduceLROnPlateau #320

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
from .pt_callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateauScheduler, GradientAccumulationScheduler
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like a relative import which we shall not use... :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Relative imports of EarlyStopping, ModelCheckpoint etc. are taken from the original repository. Why is relative import of ReduceLROnPlateauScheduler inappropriate?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#402 never fixed relative imports in callbacks init as well as in many other places. I'd say that above comment is out of scope for this PR. @Borda it might be better to create a separate PR that will properly fix relative imports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was not that the PR was fixing relative imports, but I tried to make them which was stopped...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do a separate PR for relative imports.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have opened ticket #459


__all__ = [
'EarlyStopping',
'ModelCheckpoint',
'ReduceLROnPlateauScheduler',
'GradientAccumulationScheduler',
]
30 changes: 30 additions & 0 deletions pytorch_lightning/callbacks/pt_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,36 @@ def on_train_end(self, logs=None):
print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))


class ReduceLROnPlateauScheduler(Callback):
"""
Reduce learning rate when the monitored metric has stopped improving.
Wrapper for torch.optim.lr_schuduler.ReduceLROnPlateau learning rate
schedulers.

# Arguments
schedulers: list of torch.optim.lr_scheduler.ReduceLROnPlateau
monitor: quantity to be monitored.
"""

def __init__(self, schedulers, monitor='val_loss'):
super(ReduceLROnPlateauScheduler, self).__init__()

self.monitor = monitor
self.schedulers = schedulers

def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
stop_training = False
if current is None:
print('ReduceLROnPlateau conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
exit(-1)

for scheduler in self.schedulers:
scheduler.step(current, epoch=epoch)


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to create our own ReduceLROnPlateauScheduler?
We should be operating directly on the PyTorch one (https://pytorch.org/docs/stable/optim.html?highlight=reducelr#torch.optim.lr_scheduler.ReduceLROnPlateau)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReduceLROnPlateauScheduler.schedulers is a list of orginal torch.optim.lr_scheduler.ReduceLROnPlateau, see the proof in a comment below

class ModelCheckpoint(Callback):
"""Save the model after every epoch.
`filepath` can contain named formatting options,
Expand Down
Loading