Support torch.optim.lr_scheduler.ReduceLROnPlateau #320
```diff
@@ -1,7 +1,8 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
+from .pt_callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateauScheduler, GradientAccumulationScheduler
 
 __all__ = [
     'EarlyStopping',
     'ModelCheckpoint',
+    'ReduceLROnPlateauScheduler',
     'GradientAccumulationScheduler',
 ]
```
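With the class re-exported here, user code can import it from the callbacks package. A one-line sketch; the absolute package path is an assumption based on the repository layout, since the diff itself only shows the relative import:

```python
# Assumed absolute import path; the diff above only shows `.pt_callbacks`.
from pytorch_lightning.callbacks import ReduceLROnPlateauScheduler
```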
```diff
@@ -144,6 +144,36 @@ def on_train_end(self, logs=None):
             print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
 
 
+class ReduceLROnPlateauScheduler(Callback):
+    """
+    Reduce learning rate when the monitored metric has stopped improving.
+    Wrapper for torch.optim.lr_scheduler.ReduceLROnPlateau learning rate
+    schedulers.
+
+    # Arguments
+        schedulers: list of torch.optim.lr_scheduler.ReduceLROnPlateau
+        monitor: quantity to be monitored.
+    """
+
+    def __init__(self, schedulers, monitor='val_loss'):
+        super(ReduceLROnPlateauScheduler, self).__init__()
+
+        self.monitor = monitor
+        self.schedulers = schedulers
+
+    def on_epoch_end(self, epoch, logs=None):
+        current = logs.get(self.monitor)
+        if current is None:
+            print('ReduceLROnPlateau conditioned on metric `%s` '
+                  'which is not available. Available metrics are: %s' %
+                  (self.monitor, ','.join(list(logs.keys()))))
+            exit(-1)
+
+        # step every wrapped torch scheduler with the monitored value
+        for scheduler in self.schedulers:
+            scheduler.step(current, epoch=epoch)
+
+
 class ModelCheckpoint(Callback):
     """Save the model after every epoch.
     `filepath` can contain named formatting options,
```

Review comment: Why do we need to create our own `ReduceLROnPlateauScheduler`?
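For context on that question, here is a minimal sketch of how the wrapper is meant to pair with the stock torch scheduler. The model, the optimizer, and the direct call to `on_epoch_end` (normally made by the trainer loop) are illustrative assumptions, and the class is assumed to be in scope as defined in the diff above; only `ReduceLROnPlateauScheduler` and `torch.optim.lr_scheduler.ReduceLROnPlateau` come from this PR:

```python
import torch
import torch.nn as nn

# Hypothetical model and optimizer, for illustration only.
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# One stock torch scheduler per optimizer; here it halves the LR after
# three epochs without improvement in the monitored metric.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3)

# The callback from this PR wraps a list of such schedulers.
callback = ReduceLROnPlateauScheduler([scheduler], monitor='val_loss')

# The trainer would invoke this hook once per epoch with the logged
# metrics; two stagnating epochs are simulated here.
for epoch, val_loss in enumerate([0.9, 0.9]):
    callback.on_epoch_end(epoch, logs={'val_loss': val_loss})
```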
Review comment: it looks like a relative import, which we shall not use... :)
Reply: Relative imports of `EarlyStopping`, `ModelCheckpoint`, etc. are taken from the original repository. Why is a relative import of `ReduceLROnPlateauScheduler` inappropriate?
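For reference, the two import styles under discussion. A sketch, assuming the module lives at `pytorch_lightning/callbacks/pt_callbacks.py` as the diff suggests (the two lines are alternatives, not meant to run together):

```python
# Relative import, as this PR writes it in the callbacks __init__:
from .pt_callbacks import ReduceLROnPlateauScheduler

# Absolute import, the style the reviewer is advocating:
from pytorch_lightning.callbacks.pt_callbacks import ReduceLROnPlateauScheduler
```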
Reply: #402 (comment)
Reply: #402 never fixed the relative imports in the callbacks `__init__`, nor in many other places. I'd say the comment above is out of scope for this PR. @Borda, it might be better to open a separate PR that properly fixes relative imports.
Reply: It was not that the PR was about fixing relative imports, but I did try to fix them there and was stopped...
Reply: Let's do a separate PR for relative imports.
Reply: I have opened ticket #459.