15
15
from pytorch_lightning .callbacks import Callback
16
16
from pytorch_lightning import _logger as log
17
17
from pytorch_lightning .utilities .exceptions import MisconfigurationException
18
+ from pytorch_lightning .utilities import rank_zero_warn
18
19
19
20
20
21
class TrainerLRFinderMixin (ABC ):
@@ -58,7 +59,8 @@ def lr_find(self,
58
59
max_lr : float = 1 ,
59
60
num_training : int = 100 ,
60
61
mode : str = 'exponential' ,
61
- num_accumulation_steps : int = 1 ):
62
+ early_stop_threshold : float = 4.0 ,
63
+ num_accumulation_steps = None ):
62
64
r"""
63
65
lr_find enables the user to do a range test of good initial learning rates,
64
66
to reduce the amount of guesswork in picking a good starting learning rate.
@@ -81,7 +83,12 @@ def lr_find(self,
81
83
after each batch. If set to 'exponential', will increase learning
82
84
rate exponentially.
83
85
84
- num_accumulation_steps: number of batches to calculate loss over.
86
+ early_stop_threshold: threshold for stopping the search. If the
87
+ loss at any point is larger than early_stop_threshold*best_loss
88
+ then the search is stopped. To disable, set to None.
89
+
90
+ num_accumulation_steps: deprecated, number of batches to calculate loss over.
91
+ Set trainer argument ``accumulate_grad_batches`` instead.
85
92
86
93
Example::
87
94
@@ -104,6 +111,12 @@ def lr_find(self,
104
111
trainer.fit(model)
105
112
106
113
"""
114
+ if num_accumulation_steps is not None :
115
+ rank_zero_warn ("Argument `num_accumulation_steps` has been deprepecated"
116
+ " since v0.7.6 and will be removed in 0.9. Please"
117
+ " set trainer argument `accumulate_grad_batches` instead." ,
118
+ DeprecationWarning )
119
+
107
120
save_path = os .path .join (self .default_root_dir , 'lr_find_temp.ckpt' )
108
121
109
122
self .__lr_finder_dump_params (model )
@@ -115,7 +128,9 @@ def lr_find(self,
115
128
lr_finder = _LRFinder (mode , min_lr , max_lr , num_training )
116
129
117
130
# Use special lr logger callback
118
- self .callbacks = [_LRCallback (num_training , progress_bar_refresh_rate = 1 )]
131
+ self .callbacks = [_LRCallback (num_training ,
132
+ early_stop_threshold ,
133
+ progress_bar_refresh_rate = 1 )]
119
134
120
135
# No logging
121
136
self .logger = None
@@ -127,9 +142,6 @@ def lr_find(self,
127
142
if self .progress_bar_callback :
128
143
self .progress_bar_callback .disable ()
129
144
130
- # Accumulation of gradients
131
- self .accumulate_grad_batches = num_accumulation_steps
132
-
133
145
# Disable standard checkpoint & early stopping
134
146
self .checkpoint_callback = False
135
147
self .early_stop_callback = None
@@ -149,7 +161,6 @@ def lr_find(self,
149
161
raise MisconfigurationException (
150
162
f'`model.configure_optimizers()` returned { len (optimizers )} , but'
151
163
' learning rate finder only works with single optimizer' )
152
- configure_optimizers = model .configure_optimizers
153
164
model .configure_optimizers = lr_finder ._get_new_optimizer (optimizers [0 ])
154
165
155
166
# Fit, lr & loss logged in callback
@@ -164,6 +175,7 @@ def lr_find(self,
164
175
# Transfer results from callback to lr finder object
165
176
lr_finder .results .update ({'lr' : self .callbacks [0 ].lrs ,
166
177
'loss' : self .callbacks [0 ].losses })
178
+ lr_finder ._total_batch_idx = self .total_batch_idx # for debug purpose
167
179
168
180
# Reset model state
169
181
self .restore (str (save_path ), on_gpu = self .on_gpu )
@@ -184,7 +196,6 @@ def __lr_finder_dump_params(self, model):
184
196
'logger' : self .logger ,
185
197
'max_steps' : self .max_steps ,
186
198
'progress_bar_refresh_rate' : self .progress_bar_refresh_rate ,
187
- 'accumulate_grad_batches' : self .accumulate_grad_batches ,
188
199
'checkpoint_callback' : self .checkpoint_callback ,
189
200
'early_stop_callback' : self .early_stop_callback ,
190
201
'enable_early_stop' : self .enable_early_stop ,
@@ -198,7 +209,6 @@ def __lr_finder_restore_params(self, model):
198
209
self .callbacks = self .__dumped_params ['callbacks' ]
199
210
self .max_steps = self .__dumped_params ['max_steps' ]
200
211
self .progress_bar_refresh_rate = self .__dumped_params ['progress_bar_refresh_rate' ]
201
- self .accumulate_grad_batches = self .__dumped_params ['accumulate_grad_batches' ]
202
212
self .checkpoint_callback = self .__dumped_params ['checkpoint_callback' ]
203
213
self .early_stop_callback = self .__dumped_params ['early_stop_callback' ]
204
214
self .enable_early_stop = self .__dumped_params ['enable_early_stop' ]
@@ -242,6 +252,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
242
252
self .num_training = num_training
243
253
244
254
self .results = {}
255
+ self ._total_batch_idx = 0 # for debug purpose
245
256
246
257
def _get_new_optimizer (self , optimizer : torch .optim .Optimizer ):
247
258
""" Construct a new `configure_optimizers()` method, that has a optimizer
@@ -298,30 +309,49 @@ def plot(self, suggest: bool = False, show: bool = False):
298
309
299
310
return fig
300
311
301
def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
    """ Propose an initial learning rate, chosen as the point where the
    recorded loss has its steepest negative gradient.

    Args:
        skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
        skip_end: how many samples to skip in the end. Prevent too optimistic estimates.
            ``0`` keeps every trailing sample.

    Returns:
        lr: suggested initial learning rate to use, or ``None`` when the
            suggestion could not be computed (the failure is logged and
            ``self._optimal_idx`` is reset to ``None``).

    """
    try:
        # ``seq[a:-0]`` is ``seq[a:0]``, i.e. an empty slice, so a zero
        # ``skip_end`` has to be mapped to an open-ended stop instead.
        stop = -skip_end if skip_end else None
        loss = self.results["loss"][skip_begin:stop]
        min_grad = np.gradient(np.array(loss)).argmin()
        # ``min_grad`` indexes the trimmed curve; shift it back so that it
        # addresses the untrimmed ``results`` lists.
        self._optimal_idx = min_grad + skip_begin
        return self.results["lr"][self._optimal_idx]
    except Exception:
        # Best effort: too few recorded points must not crash the caller.
        log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
        self._optimal_idx = None
        return None
317
330
318
331
319
332
class _LRCallback (Callback ):
320
333
""" Special callback used by the learning rate finder. This callbacks log
321
334
the learning rate before each batch and log the corresponding loss after
322
- each batch. """
323
- def __init__ (self , num_training : int , progress_bar_refresh_rate : bool = False , beta : float = 0.98 ):
335
+ each batch.
336
+
337
+ Args:
338
+ num_training: number of iterations done by the learning rate finder
339
+ early_stop_threshold: threshold for stopping the search. If the
340
+ loss at any point is larger than ``early_stop_threshold*best_loss``
341
+ then the search is stopped. To disable, set to ``None``.
342
+ progress_bar_refresh_rate: rate to refresh the progress bar for
343
+ the learning rate finder
344
+ beta: smoothing value, the loss being logged is a running average of
345
+ loss values logged until now. ``beta`` controls the forget rate i.e.
346
+ if ``beta=0`` all past information is ignored.
347
+
348
+ """
349
+ def __init__ (self , num_training : int ,
350
+ early_stop_threshold : float = 4.0 ,
351
+ progress_bar_refresh_rate : bool = False ,
352
+ beta : float = 0.98 ):
324
353
self .num_training = num_training
354
+ self .early_stop_threshold = early_stop_threshold
325
355
self .beta = beta
326
356
self .losses = []
327
357
self .lrs = []
@@ -332,13 +362,19 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, b
332
362
333
363
def on_batch_start(self, trainer, pl_module):
    """ Record the learning rate that is about to be used for this batch.

    Nothing happens unless ``trainer.batch_idx + 1`` is a multiple of
    ``trainer.accumulate_grad_batches`` (presumably the batch that closes a
    gradient-accumulation window -- confirm against the trainer loop).
    On such batches the progress bar is created on first use, when a
    refresh rate is set, and the first learning rate of the first
    scheduler is appended to ``self.lrs``.
    """
    closes_window = (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0
    if closes_window:
        bar_missing = self.progress_bar_refresh_rate and self.progress_bar is None
        if bar_missing:
            self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

        scheduler = trainer.lr_schedulers[0]['scheduler']
        self.lrs.append(scheduler.lr[0])
339
372
340
373
def on_batch_end (self , trainer , pl_module ):
341
374
""" Called when the training batch ends, logs the calculated loss """
375
+ if (trainer .batch_idx + 1 ) % trainer .accumulate_grad_batches != 0 :
376
+ return
377
+
342
378
if self .progress_bar :
343
379
self .progress_bar .update ()
344
380
@@ -350,10 +386,11 @@ def on_batch_end(self, trainer, pl_module):
350
386
smoothed_loss = self .avg_loss / (1 - self .beta ** current_step )
351
387
352
388
# Check if we are diverging
353
- if current_step > 1 and smoothed_loss > 4 * self .best_loss :
354
- trainer .max_steps = current_step # stop signal
355
- if self .progress_bar :
356
- self .progress_bar .close ()
389
+ if self .early_stop_threshold is not None :
390
+ if current_step > 1 and smoothed_loss > self .early_stop_threshold * self .best_loss :
391
+ trainer .max_steps = current_step # stop signal
392
+ if self .progress_bar :
393
+ self .progress_bar .close ()
357
394
358
395
# Save best loss for diverging checking
359
396
if smoothed_loss < self .best_loss or current_step == 1 :
0 commit comments