
Commit 68725ec

hadimtullie authored and committed

Callbacks [wip] (Lightning-AI#889)

* Add callback system + associated test
* Add trainer and pl_module args to callback methods
* typing
* typo in docstring
* Switch to on_.*_start()
* fix on_test_start
* fix the mess after rebasing
1 parent 2e0a199 commit 68725ec

14 files changed: +407 −87 lines

docs/source/callbacks.rst

+2 −2

@@ -9,6 +9,6 @@ Callbacks
         _save_model,
         on_epoch_end,
         on_train_end,
-        on_epoch_begin,
+        on_epoch_start,
         check_monitor_top_k,
-        on_train_begin,
+        on_train_start,

docs/source/loggers.rst

+1 −1

@@ -9,4 +9,4 @@ Loggers
         _save_model,
         on_epoch_end,
         on_train_end,
-        on_epoch_begin,
+        on_epoch_start,

pytorch_lightning/__init__.py

+2

@@ -29,10 +29,12 @@

 from .core import data_loader, LightningModule
 from .trainer import Trainer
+from .callbacks import Callback

 __all__ = [
     'Trainer',
     'LightningModule',
+    'Callback',
     'data_loader',
 ]
 # __call__ = __all__
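With the new re-export, the base class is importable from the package root as well as from the callbacks subpackage. A trivial check, using nothing beyond what the diff adds:

# both imports now resolve to the same class
from pytorch_lightning import Callback
import pytorch_lightning.callbacks
assert Callback is pytorch_lightning.callbacks.Callback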

pytorch_lightning/callbacks/base.py

+23 −23

@@ -8,61 +8,61 @@
 import abc


-_NO_TRAINER_ERROR_MSG = ".set_trainer() should be called after the callback initialization"
-
-
 class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""

-    def __init__(self):
-        self._trainer = None
+    def on_init_start(self, trainer, pl_module):
+        """Called when the trainer initialization begins."""
+        assert pl_module is None
+
+    def on_init_end(self, trainer, pl_module):
+        """Called when the trainer initialization ends."""
+        pass

-    @property
-    def trainer(self):
-        assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
-        return self._trainer
+    def on_fit_start(self, trainer, pl_module):
+        """Called when the fit begins."""
+        pass

-    def set_trainer(self, trainer):
-        """Make a link to the trainer, so different things like `trainer.current_epoch`,
-        `trainer.batch_idx`, `trainer.global_step` can be used."""
-        self._trainer = trainer
+    def on_fit_end(self, trainer, pl_module):
+        """Called when the fit ends."""
+        pass

-    def on_epoch_begin(self):
+    def on_epoch_start(self, trainer, pl_module):
         """Called when the epoch begins."""
         pass

-    def on_epoch_end(self):
+    def on_epoch_end(self, trainer, pl_module):
         """Called when the epoch ends."""
         pass

-    def on_batch_begin(self):
+    def on_batch_start(self, trainer, pl_module):
         """Called when the training batch begins."""
         pass

-    def on_batch_end(self):
+    def on_batch_end(self, trainer, pl_module):
         """Called when the training batch ends."""
         pass

-    def on_train_begin(self):
+    def on_train_start(self, trainer, pl_module):
         """Called when the train begins."""
         pass

-    def on_train_end(self):
+    def on_train_end(self, trainer, pl_module):
         """Called when the train ends."""
         pass

-    def on_validation_begin(self):
+    def on_validation_start(self, trainer, pl_module):
         """Called when the validation loop begins."""
         pass

-    def on_validation_end(self):
+    def on_validation_end(self, trainer, pl_module):
         """Called when the validation loop ends."""
         pass

-    def on_test_begin(self):
+    def on_test_start(self, trainer, pl_module):
         """Called when the test begins."""
         pass

-    def on_test_end(self):
+    def on_test_end(self, trainer, pl_module):
         """Called when the test ends."""
         pass
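Every hook now receives the trainer and the LightningModule explicitly instead of going through a stored set_trainer() link. A minimal sketch of a user-defined callback written against the new signatures; the class name and the metric bookkeeping are illustrative, not part of this commit:

from pytorch_lightning import Callback


class LossHistory(Callback):
    """Hypothetical callback: records callback_metrics at the end of every epoch."""

    def __init__(self):
        self.metrics_history = []

    def on_epoch_end(self, trainer, pl_module):
        # trainer is passed in directly; no set_trainer()/self.trainer needed anymore
        self.metrics_history.append(dict(trainer.callback_metrics))

    def on_train_end(self, trainer, pl_module):
        print(f"collected {len(self.metrics_history)} epochs of metrics")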

pytorch_lightning/callbacks/early_stopping.py

+7 −7

@@ -64,7 +64,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.monitor_op = mode_dict[mode]
         self.min_delta *= 1 if self.monitor_op == np.greater else -1

-        self.on_train_begin()
+        self.on_train_start(None, None)

     def check_metrics(self, logs):
         monitor_val = logs.get(self.monitor)
@@ -82,14 +82,14 @@ def check_metrics(self, logs):

         return True

-    def on_train_begin(self):
+    def on_train_start(self, trainer, pl_module):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf

-    def on_epoch_end(self):
-        logs = self.trainer.callback_metrics
+    def on_epoch_end(self, trainer, pl_module):
+        logs = trainer.callback_metrics
         stop_training = False
         if not self.check_metrics(logs):
             return stop_training
@@ -101,13 +101,13 @@ def on_epoch_end(self):
         else:
             self.wait += 1
             if self.wait >= self.patience:
-                self.stopped_epoch = self.trainer.current_epoch
+                self.stopped_epoch = trainer.current_epoch
                 stop_training = True
-                self.on_train_end()
+                self.on_train_end(trainer, pl_module)

         return stop_training

-    def on_train_end(self):
+    def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
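For reference, a hedged usage sketch; early_stop_callback was the Trainer argument that accepted this callback at the time, and the monitor/patience values are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# illustrative values; 'val_loss' must appear in the metrics returned from validation
early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')
trainer = Trainer(early_stop_callback=early_stop)

# the trainer now invokes the hooks with explicit arguments, e.g.:
#   stop_training = early_stop.on_epoch_end(trainer, model)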

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

+1 −2

@@ -44,8 +44,7 @@ def __init__(self, scheduling: dict):
         self.scheduling = scheduling
         self.epochs = sorted(scheduling.keys())

-    def on_epoch_begin(self):
-        trainer = self.trainer
+    def on_epoch_start(self, trainer, pl_module):
         # indexing epochs from 1 (until v0.6.x)
         # In v0.8.0, ` + 1` should be removed.
         epoch = trainer.current_epoch + 1
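Construction is unchanged; only the hook signature moved to the explicit trainer argument. An illustrative sketch (the epoch keys and factors are made up, and the comment about what the hook updates reflects the existing behaviour of this callback):

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate 2 batches from epoch 5 and 4 batches from epoch 10 (illustrative schedule)
accumulator = GradientAccumulationScheduler(scheduling={5: 2, 10: 4})

# the trainer now invokes the hook as:
#   accumulator.on_epoch_start(trainer, model)
# which updates trainer.accumulate_grad_batches for the matching epoch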

pytorch_lightning/callbacks/model_checkpoint.py

+3 −3

@@ -117,9 +117,9 @@ def check_monitor_top_k(self, current):
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])

-    def on_validation_end(self):
-        logs = self.trainer.callback_metrics
-        epoch = self.trainer.current_epoch
+    def on_validation_end(self, trainer, pl_module):
+        logs = trainer.callback_metrics
+        epoch = trainer.current_epoch
         self.epochs_since_last_check += 1

         if self.save_top_k == 0:
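A corresponding usage sketch; checkpoint_callback was the Trainer argument for this callback at the time, and the filepath/monitor values are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# illustrative configuration
checkpoint = ModelCheckpoint(filepath='checkpoints/', monitor='val_loss', save_top_k=3)
trainer = Trainer(checkpoint_callback=checkpoint)

# after each validation loop the trainer calls (see evaluation_loop.py below):
#   checkpoint.on_validation_end(trainer, model)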

pytorch_lightning/trainer/callback_config.py

−6

@@ -48,9 +48,6 @@ def configure_checkpoint_callback(self):
             # if checkpoint callback used, then override the weights path
             self.weights_save_path = self.checkpoint_callback.filepath

-            # link to the trainer
-            self.checkpoint_callback.set_trainer(self)
-
         # if weights_save_path is still none here, set to current working dir
         if self.weights_save_path is None:
             self.weights_save_path = self.default_save_path
@@ -80,6 +77,3 @@ def configure_early_stopping(self, early_stop_callback):
         else:
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
-
-        if self.early_stop_callback is not None:
-            self.early_stop_callback.set_trainer(self)
pytorch_lightning/trainer/callback_hook.py (new file)

+83

@@ -0,0 +1,83 @@
+from typing import Callable
+from abc import ABC
+
+from pytorch_lightning.callbacks import Callback
+
+
+class TrainerCallbackHookMixin(ABC):
+
+    def __init__(self):
+        # this is just a summary on variables used in this abstract class,
+        # the proper values/initialisation should be done in child class
+        self.callbacks: list[Callback] = []
+        self.get_model: Callable = ...
+
+    def on_init_start(self):
+        """Called when the trainer initialization begins."""
+        for callback in self.callbacks:
+            callback.on_init_start(self, None)
+
+    def on_init_end(self):
+        """Called when the trainer initialization ends."""
+        for callback in self.callbacks:
+            callback.on_init_end(self, self.get_model())
+
+    def on_fit_start(self):
+        """Called when the fit begins."""
+        for callback in self.callbacks:
+            callback.on_fit_start(self, self.get_model())
+
+    def on_fit_end(self):
+        """Called when the fit ends."""
+        for callback in self.callbacks:
+            callback.on_fit_end(self, self.get_model())
+
+    def on_epoch_start(self):
+        """Called when the epoch begins."""
+        for callback in self.callbacks:
+            callback.on_epoch_start(self, self.get_model())
+
+    def on_epoch_end(self):
+        """Called when the epoch ends."""
+        for callback in self.callbacks:
+            callback.on_epoch_end(self, self.get_model())
+
+    def on_train_start(self):
+        """Called when the train begins."""
+        for callback in self.callbacks:
+            callback.on_train_start(self, self.get_model())
+
+    def on_train_end(self):
+        """Called when the train ends."""
+        for callback in self.callbacks:
+            callback.on_train_end(self, self.get_model())
+
+    def on_batch_start(self):
+        """Called when the training batch begins."""
+        for callback in self.callbacks:
+            callback.on_batch_start(self, self.get_model())
+
+    def on_batch_end(self):
+        """Called when the training batch ends."""
+        for callback in self.callbacks:
+            callback.on_batch_end(self, self.get_model())
+
+    def on_validation_start(self):
+        """Called when the validation loop begins."""
+        for callback in self.callbacks:
+            callback.on_validation_start(self, self.get_model())
+
+    def on_validation_end(self):
+        """Called when the validation loop ends."""
+        for callback in self.callbacks:
+            callback.on_validation_end(self, self.get_model())
+
+    def on_test_start(self):
+        """Called when the test begins."""
+        for callback in self.callbacks:
+            callback.on_test_start(self, self.get_model())
+
+    def on_test_end(self):
+        """Called when the test ends."""
+        for callback in self.callbacks:
+            callback.on_test_end(self, self.get_model())
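A minimal sketch (not part of the commit) of how the mixin fans one trainer-level event out to every registered callback; the module path, the stand-in trainer, and the Shout class are assumptions for illustration:

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin  # path assumed


class TinyTrainer(TrainerCallbackHookMixin):
    def __init__(self, callbacks):
        self.callbacks = callbacks          # normally populated by the real Trainer
        self.get_model = lambda: None       # stand-in for Trainer.get_model()


class Shout(Callback):
    def on_epoch_start(self, trainer, pl_module):
        print("epoch starting", trainer, pl_module)


TinyTrainer([Shout()]).on_epoch_start()     # fans out to every callback with (trainer, model)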

pytorch_lightning/trainer/evaluation_loop.py

+21 −1

@@ -123,6 +123,8 @@

 """

+from typing import Callable
+
 import sys
 from abc import ABC, abstractmethod

@@ -171,6 +173,12 @@ def __init__(self):
         self.reload_dataloaders_every_epoch = None
         self.progress_bar_refresh_rate = None

+        # Callback system
+        self.on_validation_start: Callable = ...
+        self.on_validation_end: Callable = ...
+        self.on_test_start: Callable = ...
+        self.on_test_end: Callable = ...
+
     @abstractmethod
     def copy_trainer_model_properties(self, model):
         # this is just empty shell for code from other class
@@ -302,6 +310,12 @@ def run_evaluation(self, test_mode: bool = False):
                " Please define and try again"
            raise MisconfigurationException(m)

+        # Validation/Test begin callbacks
+        if test_mode:
+            self.on_test_start()
+        else:
+            self.on_validation_start()
+
         # hook
         model = self.get_model()
         model.on_pre_performance_check()
@@ -363,7 +377,13 @@ def run_evaluation(self, test_mode: bool = False):

         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
-            self.checkpoint_callback.on_validation_end()
+            self.checkpoint_callback.on_validation_end(self, self.get_model())
+
+        # Validation/Test end callbacks
+        if test_mode:
+            self.on_test_end()
+        else:
+            self.on_validation_end()

     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
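A sketch of a callback that exercises the two new dispatch points added here, assuming it is registered in the trainer's callbacks list; the timing logic is illustrative, not part of this commit:

import time

from pytorch_lightning import Callback


class ValidationTimer(Callback):
    """Hypothetical callback: measures how long each validation loop takes."""

    def on_validation_start(self, trainer, pl_module):
        self._start = time.monotonic()

    def on_validation_end(self, trainer, pl_module):
        print(f"validation took {time.monotonic() - self._start:.2f}s "
              f"(epoch {trainer.current_epoch})")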
