
Commit bb68760

alexeykarnachev authored and tullie committed
Fixed configure optimizer from dict without "scheduler" key (Lightning-AI#1443)
* `configure_optimizer` from dict with only "optimizer" key. bug fixed
* autopep8
* pep8speaks suggested fixes
* CHANGELOG.md upd
1 parent 46a9825 commit bb68760
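
For context, the scenario this commit fixes is a `configure_optimizers` that returns a dictionary containing only an "optimizer" key and no "lr_scheduler". Below is a minimal sketch of such a module; the class name and layer are illustrative placeholders (not code from this diff), training hooks are omitted, and it assumes the pytorch_lightning API of this era.

import torch
import pytorch_lightning as pl


class DictOptimizerModel(pl.LightningModule):
    """Illustrative module whose optimizer config is a dict without a scheduler."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        # Only an "optimizer" key and no "lr_scheduler": before this commit,
        # init_optimizers referenced `lr_schedulers` without ever assigning it
        # in this case, which raised an UnboundLocalError.
        return {"optimizer": torch.optim.SGD(self.parameters(), lr=1e-3)}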


5 files changed: +29 -1 lines changed


CHANGELOG.md (+1)
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
 - Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
 - Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
 - Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))

pytorch_lightning/trainer/optimizers.py (+2)
@@ -39,6 +39,8 @@ def init_optimizers(
             lr_scheduler = optim_conf.get("lr_scheduler", [])
             if lr_scheduler:
                 lr_schedulers = self.configure_schedulers([lr_scheduler])
+            else:
+                lr_schedulers = []
             return [optimizer], lr_schedulers, []
 
         # multiple dictionaries
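
Read on its own, the patched single-dict branch of `init_optimizers` behaves roughly as sketched below. The standalone function name and the stubbed `configure_schedulers` argument are illustrative only; just the body mirrors the diff.

def parse_single_dict_config(optim_conf, configure_schedulers):
    """Sketch of the single-dict path in init_optimizers after this commit."""
    optimizer = optim_conf["optimizer"]
    lr_scheduler = optim_conf.get("lr_scheduler", [])
    if lr_scheduler:
        lr_schedulers = configure_schedulers([lr_scheduler])
    else:
        # The new else-branch: without it, the return statement below hit an
        # UnboundLocalError whenever "lr_scheduler" was absent from the dict.
        lr_schedulers = []
    return [optimizer], lr_schedulers, []


# e.g. parse_single_dict_config({"optimizer": some_optimizer}, lambda s: s)
# now returns ([some_optimizer], [], []) instead of raising.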

pytorch_lightning/trainer/supporters.py (+1)
@@ -20,6 +20,7 @@ class TensorRunningAccum(object):
     >>> accum.last(), accum.mean(), accum.min(), accum.max()
     (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
     """
+
     def __init__(self, window_length: int):
         self.window_length = window_length
         self.memory = torch.Tensor(self.window_length)

pytorch_lightning/trainer/trainer.py (+2 -1)
@@ -554,7 +554,8 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
                 if at[0] not in depr_arg_names):
             for allowed_type in (at for at in allowed_types if at in arg_types):
                 if isinstance(allowed_type, bool):
-                    allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                    def allowed_type(x):
+                        return bool(distutils.util.strtobool(x))
                 parser.add_argument(
                     f'--{arg}',
                     default=arg_default,
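
The lambda-to-def change above keeps the behaviour the same and rebinds the same `allowed_type` name, in line with the "pep8speaks suggested fixes" note in the commit message (assigning a lambda triggers pep8's E731). The callable is used as an argparse `type` to turn string flags into booleans. A rough sketch of that behaviour follows, with an invented `--fast_dev_run` flag name; note that `distutils.util.strtobool` existed on the Python versions the project targeted at the time but is gone in Python 3.12.

import argparse
import distutils.util


def str_to_bool(x):
    # strtobool maps "y"/"yes"/"t"/"true"/"on"/"1" to 1 and
    # "n"/"no"/"f"/"false"/"off"/"0" to 0 (raising ValueError otherwise);
    # wrapping it in bool() yields a real True/False.
    return bool(distutils.util.strtobool(x))


parser = argparse.ArgumentParser()
parser.add_argument("--fast_dev_run", type=str_to_bool, default=False)

print(parser.parse_args(["--fast_dev_run", "yes"]).fast_dev_run)    # True
print(parser.parse_args(["--fast_dev_run", "false"]).fast_dev_run)  # False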

tests/trainer/test_optimizers.py (+23)
@@ -275,3 +275,26 @@ class CurrentTestModel(
 
     # verify training completed
     assert result == 1
+
+
+def test_configure_optimizer_from_dict(tmpdir):
+    """Tests if `configure_optimizer` method could return a dictionary with
+    `optimizer` field only.
+    """
+
+    class CurrentTestModel(LightTrainDataloader, TestModelBase):
+        def configure_optimizers(self):
+            config = {
+                'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
+            }
+            return config
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentTestModel(hparams)
+
+    trainer_options = dict(default_save_path=tmpdir, max_epochs=1)
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result == 1
