ReduceLROnPlateau bug fix #1126

Merged · 3 commits · Mar 16, 2020
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -709,8 +709,8 @@ def configure_schedulers(self, schedulers: list):
             if 'scheduler' not in scheduler:
                 raise ValueError(f'Lr scheduler should have key `scheduler`',
                                  ' with item being a lr scheduler')
-            scheduler['reduce_on_plateau'] = \
-                isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
+            scheduler['reduce_on_plateau'] = isinstance(
+                scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

             lr_schedulers.append({**default_config, **scheduler})

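For context on the fix above: the old check called isinstance on the whole config dict, which can never be a ReduceLROnPlateau instance, so reduce_on_plateau stayed False whenever the scheduler was passed in dict form. A minimal standalone sketch (not part of this PR) showing the difference between the two checks:

import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.Adam([param], lr=1e-3)
scheduler = {'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer)}

# old check: tests the dict itself, so this is always False
print(isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau))
# fixed check: tests the wrapped scheduler, True for ReduceLROnPlateau
print(isinstance(scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau))
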
3 changes: 2 additions & 1 deletion tests/models/__init__.py
@@ -21,7 +21,8 @@
     LightTestDataloader,
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
 )


10 changes: 10 additions & 0 deletions tests/models/mixins.py
@@ -636,6 +636,16 @@ def configure_optimizers(self):
             [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]


+class LightTestReduceLROnPlateauMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+        return [optimizer], [lr_scheduler]
+
+
 def _get_output_metric(output, name):
     if isinstance(output, dict):
         val = output[name]
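
As a reminder of why the reduce_on_plateau flag matters (a standalone sketch, not part of this PR): ReduceLROnPlateau is the torch scheduler whose step() expects the monitored metric, unlike the other schedulers, so the trainer has to know when to pass one in. Assumptions here: an SGD optimizer, patience=0, and a metric that never improves, just to make the lr drop visible.

import torch
from torch import optim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0)

for epoch in range(3):
    val_loss = 1.0  # metric never improves, so the lr keeps being reduced
    scheduler.step(val_loss)  # step() takes the metric, unlike other schedulers
    print(optimizer.param_groups[0]['lr'])
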
37 changes: 36 additions & 1 deletion tests/trainer/test_optimizers.py
@@ -10,9 +10,12 @@
 from tests.models import (
     TestModelBase,
     LightTrainDataloader,
+    LightValidationStepMixin,
+    LightValidationMixin,
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
 )


@@ -144,3 +147,35 @@ class CurrentTestModel(
     # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times
     assert init_lr * 0.1 == adjusted_lr2, \
         'lr for optimizer 2 not adjusted correctly'
+
+
+def test_reduce_lr_on_plateau_scheduling(tmpdir):
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+            LightTestReduceLROnPlateauMixin,
+            LightTrainDataloader,
+            LightValidationMixin,
+            LightValidationStepMixin,
+            TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    assert trainer.lr_schedulers[0] == \
+        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
+             interval='epoch', frequency=1, reduce_on_plateau=True), \
+        'lr schduler was not correctly converted to dict'
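
To run just the new test locally, the standard pytest node-id invocation should work (path and test name taken from this diff):

python -m pytest tests/trainer/test_optimizers.py::test_reduce_lr_on_plateau_scheduling -v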