Commit 3f09b32

SkafteNicki, williamFalcon, and Borda authored
Learning Rate finder (#1347)
* initial structure
* rebase
* incorporate suggestions
* update CHANGELOG.md
* initial docs
* fixes based on reviews
* added trainer arg
* update docs
* added saving/restore of model state
* initial tests
* fix styling
* added more tests
* fix docs, backward compatibility and progressbar
* fix styling
* docs update
* updates based on review
* changed saving to standard functions
* consistent naming
* fix formatting
* improve docs, added support for nested fields, improve codecov
* update CHANGELOG.md
* Update lr_finder.rst
* Update pytorch_lightning/trainer/trainer.py
* Update trainer.py
* Update CHANGELOG.md
* Update path
* restoring
* test
* attribs
* docs
* doc typo

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: J. Borovec <[email protected]>
1 parent d05ac81 commit 3f09b32

File tree

9 files changed: +773 −0 lines changed


CHANGELOG.md

+1

@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
-
docs/source/_images/trainer/lr_finder.png (17 KB, new binary image; the figure referenced in the guide below)

docs/source/index.rst

+1

@@ -66,6 +66,7 @@ PyTorch Lightning Documentation

   fast_training
   hooks
   hyperparameters
   lr_finder
   multi_gpu
   multiple_loaders
   weights_loading

docs/source/lr_finder.rst

+108

@@ -0,0 +1,108 @@

Learning Rate Finder
--------------------

For training deep neural networks, selecting a good learning rate is essential
for both better performance and faster convergence. Even optimizers such as
`Adam` that self-adjust the learning rate can benefit from a well-chosen
initial value.

To reduce the amount of guesswork in choosing a good initial learning
rate, a `learning rate finder` can be used. As described in this `paper <https://arxiv.org/abs/1506.01186>`_,
a learning rate finder does a small run where the learning rate is increased
after each processed batch and the corresponding loss is logged. The result is
an `lr` vs. `loss` plot that can be used as guidance for choosing an optimal
initial lr.

.. warning:: For the moment, this feature only works with models having a single optimizer.
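
To make the procedure concrete, here is a minimal sketch of such a range test in
plain PyTorch. This is not Lightning's implementation; the exponential schedule
and the names ``min_lr``, ``max_lr`` and ``num_training`` are illustrative
assumptions.

.. code-block:: python

    def lr_range_test(model, optimizer, loss_fn, loader,
                      min_lr=1e-8, max_lr=1.0, num_training=100):
        """Increase the lr after every batch and record (lr, loss) pairs."""
        # multiplier that grows the lr exponentially from
        # min_lr to max_lr over num_training batches
        gamma = (max_lr / min_lr) ** (1.0 / num_training)
        for group in optimizer.param_groups:
            group['lr'] = min_lr

        results = {'lr': [], 'loss': []}
        for i, (x, y) in enumerate(loader):
            if i >= num_training:
                break
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            results['lr'].append(optimizer.param_groups[0]['lr'])
            results['loss'].append(loss.item())
            for group in optimizer.param_groups:
                group['lr'] *= gamma  # raise the lr for the next batch
        return results
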
Using Lightning's built-in LR finder
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the most basic use case, this feature can be enabled during trainer construction
with ``Trainer(auto_lr_find=True)``. When ``.fit(model)`` is called, the lr finder
will automatically be run before any training is done. The ``lr`` that is found
and used will be written to the console and logged together with all other
hyperparameters of the model.

.. code-block:: python

    # default, no automatic learning rate finder
    Trainer(auto_lr_find=False)

If the ``lr`` or ``learning_rate`` key exists in ``hparams``, this flag sets that
field to the learning rate that is found. In both cases, if the respective field
is not found, an error will be thrown.

.. code-block:: python

    class LitModel(LightningModule):

        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams

        def configure_optimizers(self):
            # either `hparams.lr` or `hparams.learning_rate` works here
            return Adam(self.parameters(), lr=self.hparams.lr)

    # finds learning rate automatically
    # sets hparams.lr or hparams.learning_rate to that learning rate
    Trainer(auto_lr_find=True)

To write the result to an arbitrary field of ``hparams``, pass that field's name instead:

.. code-block:: python

    # sets hparams.my_value to the learning rate that is found
    Trainer(auto_lr_find='my_value')

Under the hood, when you call ``.fit()``, this is what happens:

1. Run the learning rate finder.
2. Run the actual fit.

.. code-block:: python

    # when you call .fit() this happens
    # 1. find learning rate
    # 2. actually run fit
    trainer.fit(model)

If you want to inspect the results of the learning rate finder before doing any
actual training, or just play around with the parameters of the algorithm, this
can be done by invoking the ``lr_find`` method of the trainer. A typical example
would look like:

.. code-block:: python

    model = MyModelClass(hparams)
    trainer = pl.Trainer()

    # Run learning rate finder
    lr_finder = trainer.lr_find(model)

    # Results can be found in
    lr_finder.results

    # Plot with
    fig = lr_finder.plot(suggest=True)
    fig.show()

    # Pick point based on plot, or get suggestion
    new_lr = lr_finder.suggestion()

    # update hparams of the model
    model.hparams.lr = new_lr

    # Fit model
    trainer.fit(model)

The figure produced by ``lr_finder.plot()`` should look something like the figure
below. It is recommended not to pick the learning rate that achieves the lowest
loss, but instead something in the middle of the sharpest downward slope (red point).
This is the point returned by ``lr_finder.suggestion()``.

.. figure:: /_images/trainer/lr_finder.png
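
The suggested point corresponds to the steepest descent of the recorded loss curve.
A minimal sketch of that idea, assuming the suggestion is simply the lr with the
most negative loss slope (the trimming of noisy endpoints is also an assumption):

.. code-block:: python

    import numpy as np

    def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
        """Pick the lr where the recorded loss curve falls most steeply."""
        lrs = np.asarray(lrs[skip_begin:-skip_end])
        losses = np.asarray(losses[skip_begin:-skip_end])
        # index of the most negative slope, i.e. the sharpest drop
        steepest = int(np.gradient(losses).argmin())
        return lrs[steepest]
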
The parameters of the algorithm can be seen below.

.. autoclass:: pytorch_lightning.trainer.lr_finder.TrainerLRFinderMixin
   :members: lr_find
   :noindex:
   :exclude-members: _run_lr_finder_internally, save_checkpoint, restore
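
As a usage sketch, the search behaviour can be tuned through keyword arguments of
``lr_find``. The argument names below (``min_lr``, ``max_lr``, ``num_training``)
follow the range-test description above and should be treated as assumptions; the
signature rendered above is authoritative.

.. code-block:: python

    # widen/narrow the search range and run more test batches
    lr_finder = trainer.lr_find(
        model,
        min_lr=1e-6,       # lower bound of the candidate learning rates
        max_lr=1e-1,       # upper bound of the candidate learning rates
        num_training=200,  # number of batches used for the test
    )
    new_lr = lr_finder.suggestion()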

pytorch_lightning/trainer/__init__.py

+21

@@ -135,6 +135,27 @@ def forward(self, x):

    # default used by the Trainer
    trainer = Trainer(amp_level='O1')

auto_lr_find
^^^^^^^^^^^^
Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
before any training, to find an optimal initial learning rate.

.. code-block:: python

    # default used by the Trainer (no learning rate finder)
    trainer = Trainer(auto_lr_find=False)

Example::

    # run learning rate finder, results override hparams.learning_rate
    trainer = Trainer(auto_lr_find=True)

    # run learning rate finder, results override hparams.my_lr_arg
    trainer = Trainer(auto_lr_find='my_lr_arg')

.. note::
    See the `learning rate finder guide <lr_finder.rst>`_

benchmark
^^^^^^^^^
