 import numpy as np

-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-

 class Callback(object):
-    r"""Abstract base class used to build new callbacks.
-    """
+    """Abstract base class used to build new callbacks."""

     def __init__(self):
-        self.validation_data = None
-        self.model = None
-
-    def set_params(self, params):
-        self.params = params
+        self._trainer = None

-    def set_model(self, model):
-        if isinstance(model, LightningDistributedDataParallel):
-            model = model.module
-        self.model = model
+    def set_trainer(self, trainer):
+        """Link this callback to the trainer, so that attributes such as
+        `trainer.current_epoch`, `trainer.batch_idx` and `trainer.global_step` can be used."""
+        self._trainer = trainer

-    def on_epoch_begin(self, epoch, logs=None):
-        """
-        called when the epoch begins
+    def on_epoch_begin(self):
+        """Called when the epoch begins."""
+        pass

-        Args:
-            epoch (int): current epoch
-            logs (dict): key-value pairs of quantities to monitor
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        pass

-        Example:
+    def on_batch_begin(self):
+        """Called when the training batch begins."""
+        pass

-            on_epoch_begin(epoch=2, logs={'val_loss': 0.2})
-        """
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        pass

-    def on_epoch_end(self, epoch, logs=None):
+    def on_train_begin(self):
+        """Called when training begins."""
         pass

-    def on_batch_begin(self, batch, logs=None):
-        """
-        called when the batch starts.
+    def on_train_end(self):
+        """Called when training ends."""
+        pass

-        Args:
-            batch (Tensor): current batch tensor
-            logs (dict): key-value pairs of quantities to monitor
-        """
+    def on_validation_begin(self):
+        """Called when the validation loop begins."""
+        pass

-    def on_batch_end(self, batch, logs=None):
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
         pass

-    def on_train_begin(self, logs=None):
+    def on_test_begin(self):
+        """Called when the test begins."""
         pass

-    def on_train_end(self, logs=None):
+    def on_test_end(self):
+        """Called when the test ends."""
         pass


+_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
+
+
 class EarlyStopping(Callback):
     r"""
     Stop training when a monitored quantity has stopped improving.
@@ -148,13 +150,16 @@ def check_metrics(self, logs):

         return True

-    def on_train_begin(self, logs=None):
+    def on_train_begin(self):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf

-    def on_epoch_end(self, epoch, logs=None):
+    def on_epoch_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
         stop_training = False
         if not self.check_metrics(logs):
             return stop_training
@@ -166,13 +171,13 @@ def on_epoch_end(self, epoch, logs=None):
         else:
             self.wait += 1
             if self.wait >= self.patience:
-                self.stopped_epoch = epoch
+                self.stopped_epoch = self._trainer.current_epoch
                 stop_training = True
                 self.on_train_end()

         return stop_training

-    def on_train_end(self, logs=None):
+    def on_train_end(self):
         if self.stopped_epoch > 0 and self.verbose > 0:
             warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
@@ -306,8 +311,11 @@ def check_monitor_top_k(self, current):
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])

-    def on_epoch_end(self, epoch, logs=None):
-        logs = logs or {}
+    def on_validation_end(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        logs = self._trainer.callback_metrics
+        epoch = self._trainer.current_epoch
         self.epochs_since_last_check += 1

         if self.save_top_k == 0:
@@ -389,6 +397,8 @@ class GradientAccumulationScheduler(Callback):
     """

     def __init__(self, scheduling: dict):
+        super().__init__()
+
         if scheduling == {}:  # empty dict error
             raise TypeError("Empty dict cannot be interpreted correctly")

@@ -408,21 +418,14 @@ def __init__(self, scheduling: dict):
         self.scheduling = scheduling
         self.epochs = sorted(scheduling.keys())

-    def on_epoch_begin(self, epoch, trainer):
+    def on_epoch_begin(self):
+        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
+
+        trainer = self._trainer
         # indexing epochs from 1 (until v0.6.x)
-        # In v0.8.0, `epoch += 1` should be removed.
-        epoch += 1
+        # In v0.8.0, ` + 1` should be removed.
+        epoch = trainer.current_epoch + 1
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                 break
-
-
-# if __name__ == '__main__':
-#     c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
-#     losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
-#     for i, loss in enumerate(losses):
-#         should_stop = c.on_epoch_end(i, logs={'val_loss': loss})
-#         log.info(loss)
-#         if should_stop:
-#             break
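
To see how the reworked API fits together: hooks no longer receive `epoch`/`logs` arguments and instead read everything off the linked trainer. Below is a minimal, self-contained sketch of driving a callback by hand. It assumes only the `Callback` base class and `_NO_TRAINER_ERROR_MSG` constant from the diff above; `PrintingCallback` and `DummyTrainer` are hypothetical names invented for illustration, not part of the library (in real use, the Trainer itself calls `set_trainer()` and fires the hooks).

class PrintingCallback(Callback):
    """Toy callback: reads state from the linked trainer instead of arguments."""

    def on_epoch_end(self):
        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
        print(self._trainer.current_epoch, self._trainer.callback_metrics)


class DummyTrainer:
    """Hypothetical stand-in for the real Trainer, just enough to fire the hooks."""

    def __init__(self):
        self.current_epoch = 0
        self.callback_metrics = {}


trainer = DummyTrainer()
callback = PrintingCallback()
callback.set_trainer(trainer)  # link first, otherwise the assert above fires

for epoch in range(3):
    trainer.current_epoch = epoch
    trainer.callback_metrics = {'val_loss': 1.0 / (epoch + 1)}
    callback.on_epoch_begin()
    callback.on_epoch_end()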