
Commit 384e124

Authored by Nicki Skafte (SkafteNicki), with co-author Nicki Skafte
ReduceLROnPlateau bug fix (#1126)
* bug fix and test
* update CHANGELOG.md

Co-authored-by: Nicki Skafte <[email protected]>
1 parent 774d9be commit 384e124

File tree: 5 files changed (+51, -5 lines)

CHANGELOG.md (+1, -1)

@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))

 ## [0.7.1] - 2020-03-07

pytorch_lightning/trainer/trainer.py (+2, -2)

@@ -707,8 +707,8 @@ def configure_schedulers(self, schedulers: list):
             if 'scheduler' not in scheduler:
                 raise ValueError(f'Lr scheduler should have key `scheduler`',
                                  ' with item being a lr scheduler')
-            scheduler['reduce_on_plateau'] = \
-                isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
+            scheduler['reduce_on_plateau'] = isinstance(
+                scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

             lr_schedulers.append({**default_config, **scheduler})
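Why the one-line change matters: inside `configure_schedulers` the loop variable `scheduler` is a config dict, so the old check ran `isinstance` on the dict itself and always returned `False`; `ReduceLROnPlateau` schedulers were therefore never flagged. A minimal standalone sketch of the two checks (plain PyTorch, not the Lightning code itself):

```python
import torch
from torch import optim

# Stand-in model and optimizer for the demonstration.
model = torch.nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.02)

# Inside configure_schedulers, each entry is a dict holding the
# actual scheduler under the 'scheduler' key.
scheduler = {'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer)}

# Old check: tests the dict, so it is always False.
print(isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau))  # False

# Fixed check: tests the scheduler object stored in the dict.
print(isinstance(scheduler['scheduler'],
                 optim.lr_scheduler.ReduceLROnPlateau))  # True
```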

tests/models/__init__.py (+2, -1)

@@ -24,7 +24,8 @@
     LightInfTestDataloader,
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
 )


tests/models/mixins.py (+10)

@@ -678,6 +678,16 @@ def configure_optimizers(self):
             [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]


+class LightTestReduceLROnPlateauMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+        return [optimizer], [lr_scheduler]
+
+
 def _get_output_metric(output, name):
     if isinstance(output, dict):
         val = output[name]
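For context on why the trainer needs the `reduce_on_plateau` flag at all: unlike step-count schedulers, `ReduceLROnPlateau.step()` must be passed the monitored metric, so the trainer has to know to feed it the validation loss. A short standalone illustration in plain PyTorch (`patience=0` is chosen here only to make the effect visible immediately):

```python
import torch
from torch import optim

model = torch.nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.02)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=0)

# A validation loss that has stopped improving.
for val_loss in [1.0, 1.0, 1.0]:
    optimizer.step()
    scheduler.step(val_loss)  # the metric argument is required here

# The lr has been reduced from 0.02 because the metric plateaued.
print(optimizer.param_groups[0]['lr'])
```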

tests/trainer/test_optimizers.py (+36, -1)

@@ -10,9 +10,12 @@
 from tests.models import (
     TestModelBase,
     LightTrainDataloader,
+    LightValidationStepMixin,
+    LightValidationMixin,
     LightTestOptimizerWithSchedulingMixin,
     LightTestMultipleOptimizersWithSchedulingMixin,
-    LightTestOptimizersWithMixedSchedulingMixin
+    LightTestOptimizersWithMixedSchedulingMixin,
+    LightTestReduceLROnPlateauMixin
 )


@@ -144,3 +147,35 @@ class CurrentTestModel(
     # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times
     assert init_lr * 0.1 == adjusted_lr2, \
         'lr for optimizer 2 not adjusted correctly'
+
+
+def test_reduce_lr_on_plateau_scheduling(tmpdir):
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+            LightTestReduceLROnPlateauMixin,
+            LightTrainDataloader,
+            LightValidationMixin,
+            LightValidationStepMixin,
+            TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    assert trainer.lr_schedulers[0] == \
+        dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss',
+             interval='epoch', frequency=1, reduce_on_plateau=True), \
+        'lr scheduler was not correctly converted to dict'
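The final assertion documents the contract: a bare scheduler returned from `configure_optimizers` ends up as a dict merged over the trainer's defaults. A rough sketch of that normalization, assuming the defaults implied by the assertion (`monitor='val_loss'`, `interval='epoch'`, `frequency=1`); it mirrors the fixed logic above but is not the library code verbatim:

```python
import torch
from torch import optim

# Defaults implied by the test's expected dict (an assumption here).
DEFAULT_CONFIG = {'monitor': 'val_loss', 'interval': 'epoch',
                  'frequency': 1, 'reduce_on_plateau': False}

def configure_schedulers(schedulers):
    lr_schedulers = []
    for scheduler in schedulers:
        if not isinstance(scheduler, dict):
            # Bare schedulers get wrapped so every entry has a 'scheduler' key.
            scheduler = {'scheduler': scheduler}
        scheduler['reduce_on_plateau'] = isinstance(
            scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
        lr_schedulers.append({**DEFAULT_CONFIG, **scheduler})
    return lr_schedulers

model = torch.nn.Linear(4, 2)
opt = optim.Adam(model.parameters(), lr=0.02)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt)

normalized = configure_schedulers([sched])[0]
assert normalized['reduce_on_plateau'] is True
assert normalized['monitor'] == 'val_loss'
```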
