
Commit d962ab5

Fix lr key name in case of param groups (#1719)
* Fix lr key name in case of param groups
* Add tests
* Update test and added configure_optimizers__param_groups
* Update CHANGELOG
1 parent: 7f64ad7 · commit: d962ab5

5 files changed: +38, −2 lines


CHANGELOG.md (+2)
@@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
 
+- Fixed lr key name in case of param groups in LearningRateLogger ([#1719](https://github.com/PyTorchLightning/pytorch-lightning/pull/1719))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed

pytorch_lightning/callbacks/lr_logger.py (+2, −2)
@@ -80,7 +80,7 @@ def _extract_lr(self, trainer, interval):
                 param_groups = scheduler['scheduler'].optimizer.param_groups
                 if len(param_groups) != 1:
                     for i, pg in enumerate(param_groups):
-                        lr, key = pg['lr'], f'{name}/{i + 1}'
+                        lr, key = pg['lr'], f'{name}/pg{i + 1}'
                         self.lrs[key].append(lr)
                         latest_stat[key] = lr
                 else:
@@ -109,7 +109,7 @@ def _find_names(self, lr_schedulers):
             param_groups = sch.optimizer.param_groups
             if len(param_groups) != 1:
                 for i, pg in enumerate(param_groups):
-                    temp = name + '/pg' + str(i + 1)
+                    temp = f'{name}/pg{i + 1}'
                     names.append(temp)
             else:
                 names.append(name)
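The two hunks above are the core of the fix: before this commit, `_extract_lr` built keys as f'{name}/{i + 1}' while `_find_names` announced them as f'{name}/pg{i + 1}', so the logged values and the registered names disagreed whenever an optimizer had more than one param group. Below is a minimal standalone sketch of the naming scheme after the fix; the toy nn.Linear model and the hard-coded base name 'lr-Adam' are illustrative only, not part of the commit.

from torch import nn, optim

# Toy model whose two parameters (weight, bias) are split into two param groups.
model = nn.Linear(4, 2)
params = list(model.parameters())
optimizer = optim.Adam([
    {'params': params[:1], 'lr': 1e-3},
    {'params': params[1:], 'lr': 1e-2},
])

# Key naming as in the patched callback: one key per param group, suffixed '/pg1', '/pg2', ...
name = 'lr-Adam'  # illustrative base name; the callback derives it from the optimizer class
param_groups = optimizer.param_groups
if len(param_groups) != 1:
    latest_stat = {f'{name}/pg{i + 1}': pg['lr'] for i, pg in enumerate(param_groups)}
else:
    latest_stat = {name: param_groups[0]['lr']}

print(latest_stat)  # {'lr-Adam/pg1': 0.001, 'lr-Adam/pg2': 0.01}

With a single param group the key stays the bare optimizer name, which is why the earlier multi-scheduler test asserts 'lr-Adam' / 'lr-Adam-1' while the new test below asserts 'lr-Adam/pg1' / 'lr-Adam/pg2'.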

tests/base/mixins.py (whitespace-only changes)

tests/base/model_optimizers.py (+10)
@@ -59,3 +59,13 @@ def configure_optimizers__reduce_lr_on_plateau(self):
         optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
         lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
         return [optimizer], [lr_scheduler]
+
+    def configure_optimizers__param_groups(self):
+        param_groups = [
+            {'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1},
+            {'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate}
+        ]
+
+        optimizer = optim.Adam(param_groups)
+        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
+        return [optimizer], [lr_scheduler]
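The new fixture relies on PyTorch's param-group mechanism: an optimizer accepts a list of dicts, each carrying its own 'params' and 'lr', and a scheduler such as StepLR decays every group's rate by the same factor. A small hedged sketch of that behaviour, using a toy nn.Linear model instead of the EvalModelTemplate used by the tests:

import torch
from torch import nn, optim

model = nn.Linear(4, 1)
params = list(model.parameters())
optimizer = optim.Adam([
    {'params': params[:1], 'lr': 0.001},  # weight at a reduced rate
    {'params': params[1:], 'lr': 0.01},   # bias at the base rate
])
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

print([pg['lr'] for pg in optimizer.param_groups])  # [0.001, 0.01]

# One training step followed by a scheduler step (an epoch boundary here).
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
scheduler.step()

print([pg['lr'] for pg in optimizer.param_groups])  # roughly [0.0001, 0.001]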

tests/callbacks/test_callbacks.py (+24)
@@ -331,3 +331,27 @@ def test_lr_logger_multi_lrs(tmpdir):
         'Number of learning rates logged does not match number of lr schedulers'
     assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
         'Names of learning rates not set correctly'
+
+
+def test_lr_logger_param_groups(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler"""
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    results = trainer.fit(model)
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
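Outside the test harness, the same behaviour can be exercised end to end. A hedged sketch against the 0.7.x-era API used in this diff; the TwoGroupModel module and its synthetic data are made up for illustration, while Trainer, LearningRateLogger, callbacks=[...] and lr_logger.lrs come straight from the test above.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateLogger


class TwoGroupModel(pl.LightningModule):
    """Illustrative module: two param groups with different learning rates."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        x, y = torch.randn(64, 4), torch.randn(64, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)

    def configure_optimizers(self):
        params = list(self.parameters())
        optimizer = optim.Adam([
            {'params': params[:1], 'lr': 1e-3},  # weight group
            {'params': params[1:], 'lr': 1e-2},  # bias group
        ])
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]


lr_logger = LearningRateLogger()
trainer = Trainer(max_epochs=2, callbacks=[lr_logger])
trainer.fit(TwoGroupModel())

# With the fix, one series per param group is logged under the '/pg{i}' keys.
print(sorted(lr_logger.lrs))  # expected: ['lr-Adam/pg1', 'lr-Adam/pg2']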
