Skip to content

Commit c2576c2

Browse files
committed
Add callback system + associated test
1 parent 89d5772 commit c2576c2

File tree

7 files changed

+321
-7
lines changed

7 files changed

+321
-7
lines changed

pytorch_lightning/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929

3030
from .core import data_loader, LightningModule
3131
from .trainer import Trainer
32+
from .callbacks import Callback
3233

3334
__all__ = [
3435
'Trainer',
3536
'LightningModule',
37+
'Callback',
3638
'data_loader',
3739
]
3840
# __call__ = __all__

pytorch_lightning/callbacks/base.py

+28
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,39 @@ def trainer(self):
2222
assert self._trainer is not None, _NO_TRAINER_ERROR_MSG
2323
return self._trainer
2424

25+
@property
def default_save_path(self):
    """Trainer default save path.

    Goes through the ``trainer`` property (rather than ``self._trainer``)
    so that accessing it before ``set_trainer`` was called raises a clear
    assertion error instead of an ``AttributeError`` on ``None``.
    """
    return self.trainer.default_save_path

@property
def rank(self):
    """Current trainer process rank (``trainer.proc_rank``)."""
    return self.trainer.proc_rank

def set_trainer(self, trainer):
    """Make a link to the trainer, so different things like `trainer.current_epoch`,
    `trainer.batch_idx`, `trainer.global_step` can be used."""
    self._trainer = trainer

def on_init_begin(self):
    """Called when the trainer initialization begins."""
    pass

def on_init_end(self):
    """Called when the trainer initialization ends."""
    pass

def on_fit_begin(self):
    """Called when the fit begins."""
    pass

def on_fit_end(self):
    """Called when the fit ends."""
    pass

def on_epoch_begin(self):
    """Called when the epoch begins."""
    pass
+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
from abc import ABC
from typing import List

from pytorch_lightning.callbacks import Callback
6+
class TrainerCallbackHookMixin(ABC):
    """Mixin that fans trainer lifecycle events out to user callbacks.

    The ``Trainer`` inherits this mixin and invokes each ``on_*`` hook at
    the matching point of its life cycle; every hook forwards the event to
    all callbacks registered in ``self.callbacks``.
    """

    def __init__(self):
        # this is just a summary on variables used in this abstract class,
        # the proper values/initialisation should be done in child class
        # NOTE: ``typing.List`` (not the builtin ``list``) keeps this valid
        # on Python < 3.9 — PEP 526 evaluates annotations on attribute
        # targets at runtime, so ``list[Callback]`` would raise TypeError.
        # The string form avoids evaluating ``Callback`` itself here.
        self.callbacks: List['Callback'] = []

    def on_init_begin(self):
        """Called when the trainer initialization begins.

        Also links every callback back to this trainer via ``set_trainer``
        before dispatching the event.
        """
        for callback in self.callbacks:
            callback.set_trainer(self)
            callback.on_init_begin()

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        """Called when the fit begins."""
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        """Called when the fit ends."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Called when the epoch ends."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_train_begin(self):
        """Called when the train begins."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Called when the train ends."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_batch_begin(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        for callback in self.callbacks:
            callback.on_validation_begin()

    def on_validation_end(self):
        """Called when the validation loop ends."""
        for callback in self.callbacks:
            callback.on_validation_end()

    def on_test_begin(self):
        """Called when the test begins."""
        for callback in self.callbacks:
            callback.on_test_begin()

    def on_test_end(self):
        """Called when the test ends."""
        for callback in self.callbacks:
            callback.on_test_end()

pytorch_lightning/trainer/evaluation_loop.py

+20
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@
123123
124124
"""
125125

126+
from typing import Callable
127+
126128
import sys
127129
from abc import ABC, abstractmethod
128130

@@ -169,6 +171,12 @@ def __init__(self):
169171
self.get_val_dataloaders = None
170172
self.use_tpu = None
171173

174+
# Callback system
175+
self.on_validation_begin: Callable = None
176+
self.on_validation_end: Callable = None
177+
self.on_test_begin: Callable = None
178+
self.on_test_end: Callable = None
179+
172180
@abstractmethod
173181
def copy_trainer_model_properties(self, model):
174182
# this is just empty shell for code from other class
@@ -293,6 +301,12 @@ def run_evaluation(self, test=False):
293301
Please define and try again'''
294302
raise MisconfigurationException(m)
295303

304+
# Validation/Test begin callbacks
305+
if test:
306+
self.on_test_begin()
307+
else:
308+
self.on_validation_begin()
309+
296310
# hook
297311
model = self.get_model()
298312
model.on_pre_performance_check()
@@ -353,6 +367,12 @@ def run_evaluation(self, test=False):
353367
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
354368
self.checkpoint_callback.on_validation_end()
355369

370+
# Validation/Test end callbacks
371+
if test:
372+
self.on_test_end()
373+
else:
374+
self.on_validation_end()
375+
356376
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
357377
# make dataloader_idx arg in validation_step optional
358378
args = [batch, batch_idx]

pytorch_lightning/trainer/trainer.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
from pytorch_lightning.trainer.training_io import TrainerIOMixin
3131
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
3232
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
33+
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
3334
from pytorch_lightning.utilities.debugging import MisconfigurationException
3435
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
36+
from pytorch_lightning.callbacks import Callback
3537

3638

3739
try:
@@ -62,13 +64,15 @@ class Trainer(TrainerIOMixin,
6264
TrainerEvaluationLoopMixin,
6365
TrainerTrainLoopMixin,
6466
TrainerCallbackConfigMixin,
67+
TrainerCallbackHookMixin
6568
):
6669

6770
def __init__(
6871
self,
6972
logger: Union[LightningLoggerBase, bool] = True,
7073
checkpoint_callback: Union[ModelCheckpoint, bool] = True,
7174
early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
75+
callbacks: list = [],
7276
default_save_path: Optional[str] = None,
7377
gradient_clip_val: float = 0,
7478
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@@ -168,6 +172,18 @@ def __init__(
168172
169173
trainer = Trainer(early_stop_callback=early_stop_callback)
170174
175+
callbacks (list of :class:`.Callback`): Add a list of callbacks.
176+
Example::
177+
from pytorch_lightning.callbacks import Callback
178+
class PrintCallback(Callback):
179+
def on_train_begin(self):
180+
print("Training is started!")
181+
def on_train_end(self):
182+
print(f"Training is done. The logs are: {self.trainer.logs}")
183+
# a list of callbacks
184+
callbacks = [PrintCallback()]
185+
trainer = Trainer(callbacks=callbacks)
186+
171187
default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
172188
Example::
173189
@@ -584,6 +600,10 @@ def __init__(
584600
585601
"""
586602

603+
# Init callbacks
604+
self.callbacks = callbacks
605+
self.on_init_begin()
606+
587607
# Transfer params
588608
# Backward compatibility
589609
if nb_gpu_nodes is not None:
@@ -766,6 +786,9 @@ def __init__(
766786
use_amp = True
767787
self.init_amp(use_amp)
768788

789+
# Callback system
790+
self.on_init_end()
791+
769792
@property
770793
def slurm_job_id(self) -> int:
771794
try:
@@ -901,6 +924,9 @@ def fit(
901924
_set_dataloader(model, val_dataloader, 'val_dataloader')
902925
_set_dataloader(model, test_dataloader, 'test_dataloader')
903926

927+
# Fit begin callbacks
928+
self.on_fit_begin()
929+
904930
# when using multi-node or DDP within a node start each module in a separate process
905931
if self.use_ddp2:
906932
task = int(os.environ['SLURM_LOCALID'])
@@ -940,6 +966,9 @@ def fit(
940966

941967
self.run_pretrain_routine(model)
942968

969+
# Fit end callbacks
970+
self.on_fit_end()
971+
943972
# return 1 when finished
944973
# used for testing or when we need to know that training succeeded
945974
return 1
@@ -1034,8 +1063,8 @@ def run_pretrain_routine(self, model: LightningModule):
10341063
return
10351064

10361065
# check if we should run validation during training
1037-
self.disable_validation = ((self.num_val_batches == 0 or
1038-
not self.is_overriden('validation_step')) and
1066+
self.disable_validation = ((self.num_val_batches == 0
1067+
or not self.is_overriden('validation_step')) and
10391068
not self.fast_dev_run)
10401069

10411070
# run tiny validation (if validation defined)
@@ -1139,7 +1168,7 @@ def _set_dataloader(model, dataloader, attribute):
11391168
if is_dataloader or is_dataloader_list and valid_loaders:
11401169

11411170
# Overwrite abstract methods
1142-
dl = lambda: dataloader
1171+
def dl(): return dataloader
11431172
dl.__name__ = attribute
11441173
setattr(model, attribute, dl)
11451174

0 commit comments

Comments
 (0)