
Commit 663b900

Authored by Nicki Skafte (with Jirka Borovec)
Bugfix: accumulation and suggestion for learning rate finder (#1801)
* fix suggestion being too naive
* fix accumulation error and added new tests
* fix styling
* update CHANGELOG.md
* update based on review
* fix tests
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review

Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent aefc531 commit 663b900

File tree

3 files changed: +110 −28 lines
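
Taken together, this commit deprecates the `num_accumulation_steps` argument of `lr_find` in favor of the trainer-level `accumulate_grad_batches`, adds an `early_stop_threshold` argument, and makes `suggestion()` skip noisy samples at both ends of the loss curve. A minimal usage sketch of the post-commit API (`MyModel` is a hypothetical LightningModule, not part of this diff):

    from pytorch_lightning import Trainer

    model = MyModel()  # hypothetical LightningModule standing in for any model
    trainer = Trainer(accumulate_grad_batches=2)  # replaces num_accumulation_steps

    lr_finder = trainer.lr_find(model, early_stop_threshold=4.0)
    new_lr = lr_finder.suggestion(skip_begin=10, skip_end=1)  # defaults shown
    model.hparams.learning_rate = new_lr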

CHANGELOG.md

+1 −2
@@ -70,8 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))
 
-- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794))
-
+- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))
 
 ## [0.7.5] - 2020-04-27
 

pytorch_lightning/trainer/lr_finder.py

+58 −21
@@ -15,6 +15,7 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class TrainerLRFinderMixin(ABC):
@@ -58,7 +59,8 @@ def lr_find(self,
                 max_lr: float = 1,
                 num_training: int = 100,
                 mode: str = 'exponential',
-                num_accumulation_steps: int = 1):
+                early_stop_threshold: float = 4.0,
+                num_accumulation_steps=None):
         r"""
         lr_find enables the user to do a range test of good initial learning rates,
         to reduce the amount of guesswork in picking a good starting learning rate.
@@ -81,7 +83,12 @@ def lr_find(self,
                 after each batch. If set to 'exponential', will increase learning
                 rate exponentially.
 
-            num_accumulation_steps: number of batches to calculate loss over.
+            early_stop_threshold: threshold for stopping the search. If the
+                loss at any point is larger than early_stop_threshold*best_loss
+                then the search is stopped. To disable, set to None.
+
+            num_accumulation_steps: deprecated, number of batches to calculate loss over.
+                Set trainer argument ``accumulate_grad_batches`` instead.
 
         Example::
@@ -104,6 +111,12 @@ def lr_find(self,
             trainer.fit(model)
 
         """
+        if num_accumulation_steps is not None:
+            rank_zero_warn("Argument `num_accumulation_steps` has been deprecated"
+                           " since v0.7.6 and will be removed in 0.9. Please"
+                           " set trainer argument `accumulate_grad_batches` instead.",
+                           DeprecationWarning)
+
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
 
         self.__lr_finder_dump_params(model)
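
For reference, a sketch of how the deprecation shim above behaves from user code (assuming a `trainer` and `model` as in the docstring example; `rank_zero_warn` emits a standard `DeprecationWarning`):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # old call style: still accepted, but the value itself is now ignored
        # in favor of Trainer(accumulate_grad_batches=...)
        trainer.lr_find(model, num_accumulation_steps=2)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
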
@@ -115,7 +128,9 @@ def lr_find(self,
         lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
         # Use special lr logger callback
-        self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
+        self.callbacks = [_LRCallback(num_training,
+                                      early_stop_threshold,
+                                      progress_bar_refresh_rate=1)]
 
         # No logging
         self.logger = None
@@ -127,9 +142,6 @@ def lr_find(self,
         if self.progress_bar_callback:
             self.progress_bar_callback.disable()
 
-        # Accumulation of gradients
-        self.accumulate_grad_batches = num_accumulation_steps
-
         # Disable standard checkpoint & early stopping
         self.checkpoint_callback = False
         self.early_stop_callback = None
@@ -149,7 +161,6 @@ def lr_find(self,
             raise MisconfigurationException(
                 f'`model.configure_optimizers()` returned {len(optimizers)}, but'
                 ' learning rate finder only works with single optimizer')
-        configure_optimizers = model.configure_optimizers
         model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
 
         # Fit, lr & loss logged in callback
@@ -164,6 +175,7 @@ def lr_find(self,
         # Transfer results from callback to lr finder object
         lr_finder.results.update({'lr': self.callbacks[0].lrs,
                                   'loss': self.callbacks[0].losses})
+        lr_finder._total_batch_idx = self.total_batch_idx  # for debug purpose
 
         # Reset model state
         self.restore(str(save_path), on_gpu=self.on_gpu)
@@ -184,7 +196,6 @@ def __lr_finder_dump_params(self, model):
             'logger': self.logger,
             'max_steps': self.max_steps,
             'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
-            'accumulate_grad_batches': self.accumulate_grad_batches,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
             'enable_early_stop': self.enable_early_stop,
@@ -198,7 +209,6 @@ def __lr_finder_restore_params(self, model):
         self.callbacks = self.__dumped_params['callbacks']
         self.max_steps = self.__dumped_params['max_steps']
         self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
-        self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
         self.enable_early_stop = self.__dumped_params['enable_early_stop']
@@ -242,6 +252,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.num_training = num_training
 
         self.results = {}
+        self._total_batch_idx = 0  # for debug purpose
 
     def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
         """ Construct a new `configure_optimizers()` method, that has an optimizer
@@ -298,30 +309,49 @@ def plot(self, suggest: bool = False, show: bool = False):
 
         return fig
 
-    def suggestion(self):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
         """ This will propose a suggestion for choice of initial learning rate
         as the point with the steepest negative gradient.
 
         Returns:
             lr: suggested initial learning rate to use
+            skip_begin: how many samples to skip in the beginning. Prevents too naive estimates
+            skip_end: how many samples to skip in the end. Prevents too optimistic estimates
 
         """
         try:
-            min_grad = (np.gradient(np.array(self.results["loss"]))).argmin()
-            self._optimal_idx = min_grad
-            return self.results["lr"][min_grad]
+            loss = self.results["loss"][skip_begin:-skip_end]
+            min_grad = (np.gradient(np.array(loss))).argmin()
+            self._optimal_idx = min_grad + skip_begin
+            return self.results["lr"][self._optimal_idx]
         except Exception:
-            log.warning('Failed to compute suggestion for `lr`.'
-                        ' There might not be enough points.')
+            log.exception('Failed to compute suggestion for `lr`. There might not be enough points.')
             self._optimal_idx = None
 
 
 class _LRCallback(Callback):
     """ Special callback used by the learning rate finder. This callback logs
     the learning rate before each batch and logs the corresponding loss after
-    each batch. """
-    def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
+    each batch.
+
+    Args:
+        num_training: number of iterations done by the learning rate finder
+        early_stop_threshold: threshold for stopping the search. If the
+            loss at any point is larger than ``early_stop_threshold*best_loss``
+            then the search is stopped. To disable, set to ``None``.
+        progress_bar_refresh_rate: rate to refresh the progress bar for
+            the learning rate finder
+        beta: smoothing value, the loss being logged is a running average of
+            loss values logged until now. ``beta`` controls the forget rate, i.e.
+            if ``beta=0`` all past information is ignored.
+
+    """
+    def __init__(self, num_training: int,
+                 early_stop_threshold: float = 4.0,
+                 progress_bar_refresh_rate: bool = False,
+                 beta: float = 0.98):
         self.num_training = num_training
+        self.early_stop_threshold = early_stop_threshold
         self.beta = beta
         self.losses = []
         self.lrs = []
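
The fix to `suggestion()` is easiest to see in isolation. A self-contained sketch of the same steepest-gradient rule on a synthetic loss curve (toy data, not library code):

    import numpy as np

    def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
        # Same rule as suggestion(): pick the lr at the steepest negative
        # slope of the loss curve, excluding the noisy first/last samples so
        # the argmin cannot land on the flat start or the divergent tail.
        loss = np.array(losses[skip_begin:-skip_end])
        idx = np.gradient(loss).argmin() + skip_begin  # shift back to full-curve index
        return lrs[idx]

    lrs = np.logspace(-5, 0, 100)
    losses = np.concatenate([np.full(30, 2.0),           # flat start
                             np.linspace(2.0, 0.5, 40),  # steady descent
                             np.linspace(0.5, 5.0, 30)]) # divergence
    print(suggest_lr(lrs, losses))  # lands inside the descent region
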
@@ -332,13 +362,19 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, b
 
     def on_batch_start(self, trainer, pl_module):
         """ Called before each training batch, logs the lr that will be used """
+        if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+            return
+
         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
 
         self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
 
     def on_batch_end(self, trainer, pl_module):
         """ Called when the training batch ends, logs the calculated loss """
+        if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+            return
+
         if self.progress_bar:
             self.progress_bar.update()
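
The guard added to both hooks keeps the logged lr/loss pairs aligned with optimizer steps rather than raw batches. A toy trace of the modulo check (assuming `accumulate_grad_batches=2`):

    accumulate_grad_batches = 2
    logged = [batch_idx for batch_idx in range(8)
              if (batch_idx + 1) % accumulate_grad_batches == 0]
    print(logged)  # [1, 3, 5, 7] -> one record per optimizer step

With accumulation of 2 the finder therefore still records `num_training` points but consumes twice as many batches, which is exactly what the new test below asserts via `_total_batch_idx == 100 * 2`.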

@@ -350,10 +386,11 @@ def on_batch_end(self, trainer, pl_module):
         smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
 
         # Check if we are diverging
-        if current_step > 1 and smoothed_loss > 4 * self.best_loss:
-            trainer.max_steps = current_step  # stop signal
-            if self.progress_bar:
-                self.progress_bar.close()
+        if self.early_stop_threshold is not None:
+            if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
+                trainer.max_steps = current_step  # stop signal
+                if self.progress_bar:
+                    self.progress_bar.close()
 
         # Save best loss for divergence checking
         if smoothed_loss < self.best_loss or current_step == 1:
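
The smoothing and early-stopping logic above, restated as a self-contained sketch (a re-implementation for illustration, not library code):

    def smoothed_losses(losses, beta=0.98, early_stop_threshold=4.0):
        # Running average with bias correction, as in on_batch_end:
        #   avg_t      = beta * avg_{t-1} + (1 - beta) * loss_t
        #   smoothed_t = avg_t / (1 - beta ** t)
        avg, best, out = 0.0, float('inf'), []
        for step, loss in enumerate(losses, start=1):
            avg = beta * avg + (1 - beta) * loss
            smoothed = avg / (1 - beta ** step)
            if (early_stop_threshold is not None
                    and step > 1 and smoothed > early_stop_threshold * best):
                break  # diverged: send the stop signal
            best = min(best, smoothed)
            out.append(smoothed)
        return out

Passing `early_stop_threshold=None` disables the break entirely, which is what the new test relies on to collect all `num_training` points.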

tests/trainer/test_lr_finder.py

+51 −5
@@ -76,15 +76,15 @@ def test_trainer_reset_correctly(tmpdir):
 
 
 def test_trainer_arg_bool(tmpdir):
-
+    """ Test that setting trainer arg to bool works """
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(hparams)
     before_lr = hparams.learning_rate
 
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
         auto_lr_find=True
     )

@@ -95,7 +95,7 @@ def test_trainer_arg_bool(tmpdir):
 
 
 def test_trainer_arg_str(tmpdir):
-
+    """ Test that setting trainer arg to string works """
     hparams = EvalModelTemplate.get_default_hparams()
     hparams.__dict__['my_fancy_lr'] = 1.0  # update with non-standard field
     model = EvalModelTemplate(hparams)
@@ -104,7 +104,7 @@ def test_trainer_arg_str(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
         auto_lr_find='my_fancy_lr'
     )

@@ -115,6 +115,7 @@ def test_trainer_arg_str(tmpdir):
 
 
 def test_call_to_trainer_method(tmpdir):
+    """ Test that directly calling the trainer method works """
 
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(hparams)
@@ -123,7 +124,7 @@ def test_call_to_trainer_method(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
     )
 
     lrfinder = trainer.lr_find(model, mode='linear')
@@ -133,3 +134,48 @@ def test_call_to_trainer_method(tmpdir):
 
     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
+
+
+def test_accumulation_and_early_stopping(tmpdir):
+    """ Test that early stopping of learning rate finder works, and that
+    accumulation also works for this feature """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    before_lr = hparams.learning_rate
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        accumulate_grad_batches=2
+    )
+
+    lrfinder = trainer.lr_find(model, early_stop_threshold=None)
+    after_lr = lrfinder.suggestion()
+
+    assert before_lr != after_lr, \
+        'Learning rate was not altered after running learning rate finder'
+    assert len(lrfinder.results['lr']) == 100, \
+        'Early stopping for learning rate finder did not work'
+    assert lrfinder._total_batch_idx == 100 * 2, \
+        'Accumulation parameter did not work'
+
+
+def test_suggestion_parameters_work(tmpdir):
+    """ Test that default skipping does not alter results in basic case """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=10,
+    )
+
+    lrfinder = trainer.lr_find(model)
+    lr1 = lrfinder.suggestion(skip_begin=10)  # default
+    lr2 = lrfinder.suggestion(skip_begin=80)  # way too high, should have an impact
+
+    assert lr1 != lr2, \
+        'Skipping parameter did not influence learning rate'
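
The new tests should be runnable in isolation, e.g. with something like `python -m pytest tests/trainer/test_lr_finder.py -k "accumulation or suggestion"` (invocation illustrative; exact flags depend on the repo's test setup).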
