Skip to content

Commit aa40e34

Browse files
committed
add callback system + associated tests
1 parent 980a0d1 commit aa40e34

File tree

6 files changed

+297
-0
lines changed

6 files changed

+297
-0
lines changed

pytorch_lightning/callbacks/base.py

+16
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ def set_trainer(self, trainer):
2727
`trainer.batch_idx`, `trainer.global_step` can be used."""
2828
self._trainer = trainer
2929

30+
def on_init_begin(self):
    """Hook fired when trainer initialization starts; default is a no-op."""
33+
34+
def on_init_end(self):
    """Hook fired when trainer initialization finishes; default is a no-op."""
37+
38+
def on_fit_begin(self):
    """Hook fired when the fit starts; default is a no-op."""
41+
42+
def on_fit_end(self):
    """Hook fired when the fit finishes; default is a no-op."""
45+
3046
def on_epoch_begin(self):
    """Hook fired when an epoch starts; default is a no-op."""
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
from abc import ABC
3+
4+
from pytorch_lightning.callbacks import Callback
5+
6+
7+
class TrainerCallbackHookMixin(ABC):
    """Mixin that fans each trainer lifecycle event out to every registered callback.

    The concrete ``Trainer`` inherits this mixin and populates
    ``self.callbacks``; each ``on_*`` method below simply forwards the
    corresponding event to every callback in registration order.
    """

    def __init__(self):
        # This is just a summary of the variables used in this abstract class;
        # the proper values/initialisation should be done in the child class.
        #
        # FIX: the original wrote ``self.callbacks: list[Callback] = []``.
        # Annotations on attribute targets are evaluated at runtime (PEP 526),
        # and subscripting the builtin ``list`` only works on Python 3.9+
        # (PEP 585), so that line raises ``TypeError`` on the 3.6-3.8
        # interpreters this project supports.  Keep the type in a comment.
        self.callbacks = []  # list of Callback instances

    def on_init_begin(self):
        """Called when the trainer initialization begins."""
        for callback in self.callbacks:
            # First event a callback ever sees: attach the trainer so later
            # hooks can read trainer state through ``self._trainer``.
            callback.set_trainer(self)
            callback.on_init_begin()

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        """Called when the fit begins."""
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        """Called when the fit ends."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Called when the epoch ends."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_train_begin(self):
        """Called when the train begins."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Called when the train ends."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_batch_begin(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        for callback in self.callbacks:
            callback.on_validation_begin()

    def on_validation_end(self):
        """Called when the validation loop ends."""
        for callback in self.callbacks:
            callback.on_validation_end()

    def on_test_begin(self):
        """Called when the test begins."""
        for callback in self.callbacks:
            callback.on_test_begin()

    def on_test_end(self):
        """Called when the test ends."""
        for callback in self.callbacks:
            callback.on_test_end()

pytorch_lightning/trainer/evaluation_loop.py

+18
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def __init__(self):
169169
self.get_val_dataloaders = None
170170
self.use_tpu = None
171171

172+
# Callback system
173+
self.on_validation_begin = None
174+
self.on_validation_end = None
175+
self.on_test_begin = None
176+
self.on_test_end = None
177+
172178
@abstractmethod
173179
def copy_trainer_model_properties(self, model):
174180
# this is just empty shell for code from other class
@@ -293,6 +299,12 @@ def run_evaluation(self, test=False):
293299
Please define and try again'''
294300
raise MisconfigurationException(m)
295301

302+
# Validation/Test begin callbacks
303+
if test:
304+
self.on_test_begin()
305+
else:
306+
self.on_validation_begin()
307+
296308
# hook
297309
model = self.get_model()
298310
model.on_pre_performance_check()
@@ -353,6 +365,12 @@ def run_evaluation(self, test=False):
353365
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
354366
self.checkpoint_callback.on_validation_end()
355367

368+
# Validation/Test end callbacks
369+
if test:
370+
self.on_test_end()
371+
else:
372+
self.on_validation_end()
373+
356374
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
357375
# make dataloader_idx arg in validation_step optional
358376
args = [batch, batch_idx]

pytorch_lightning/trainer/trainer.py

+32
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pytorch_lightning.trainer.training_io import TrainerIOMixin
2626
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
2727
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
28+
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
2829
from pytorch_lightning.utilities.debugging import MisconfigurationException
2930
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
3031

@@ -57,13 +58,15 @@ class Trainer(TrainerIOMixin,
5758
TrainerEvaluationLoopMixin,
5859
TrainerTrainLoopMixin,
5960
TrainerCallbackConfigMixin,
61+
TrainerCallbackHookMixin
6062
):
6163

6264
def __init__(
6365
self,
6466
logger=True,
6567
checkpoint_callback=True,
6668
early_stop_callback=None,
69+
callbacks: list = [],
6770
default_save_path=None,
6871
gradient_clip_val=0,
6972
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@@ -163,6 +166,22 @@ def __init__(
163166
164167
trainer = Trainer(early_stop_callback=early_stop_callback)
165168
169+
callbacks (list of :class:`.Callback`): Add a list of callbacks.
170+
Example::
171+
172+
from pytorch_lightning.callbacks import Callback
173+
174+
class PrintCallback(Callback):
175+
def on_train_begin(self):
176+
print("Training is started!")
177+
178+
def on_train_end(self):
179+
print(f"Training is done. The logs are: {self.trainer.logs}")
180+
181+
# a list of callbacks
182+
callbacks = [PrintCallback()]
183+
trainer = Trainer(callbacks=callbacks)
184+
166185
default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed
167186
Example::
168187
@@ -579,6 +598,10 @@ def __init__(
579598
580599
"""
581600

601+
# Init callbacks
602+
self.callbacks = callbacks
603+
self.on_init_begin()
604+
582605
# Transfer params
583606
# Backward compatibility
584607
if nb_gpu_nodes is not None:
@@ -761,6 +784,8 @@ def __init__(
761784
use_amp = True
762785
self.init_amp(use_amp)
763786

787+
self.on_init_end()
788+
764789
@property
765790
def slurm_job_id(self):
766791
try:
@@ -890,6 +915,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
890915
_set_dataloader(model, val_dataloader, 'val_dataloader')
891916
_set_dataloader(model, test_dataloader, 'test_dataloader')
892917

918+
# Training begin callbacks
919+
self.on_fit_begin()
920+
893921
# when using multi-node or DDP within a node start each module in a separate process
894922
if self.use_ddp2:
895923
task = int(os.environ['SLURM_LOCALID'])
@@ -929,6 +957,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
929957

930958
self.run_pretrain_routine(model)
931959

960+
# Training end callbacks
961+
self.on_fit_end()
962+
932963
# return 1 when finished
933964
# used for testing or when we need to know that training succeeded
934965
return 1
@@ -1082,6 +1113,7 @@ def test(self, model=None):
10821113
trainer = Trainer()
10831114
trainer.test(model)
10841115
"""
1116+
10851117
self.testing = True
10861118
if model is not None:
10871119
self.fit(model)

pytorch_lightning/trainer/training_loop.py

+30
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,16 @@ def __init__(self):
230230
self.profiler = None
231231
self.batch_idx = None
232232
self.precision = None
233+
self.callbacks = []
234+
self.max_steps = None
235+
236+
# Callback system
237+
self.on_train_begin = None
238+
self.on_train_end = None
239+
self.on_batch_begin = None
240+
self.on_batch_end = None
241+
self.on_epoch_begin = None
242+
self.on_epoch_end = None
233243

234244
@property
235245
def max_nb_epochs(self):
@@ -305,6 +315,10 @@ def process_output(self, output, train):
305315
pass
306316

307317
def train(self):
318+
319+
# Train begin callbacks
320+
self.on_train_begin()
321+
308322
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
309323
' but will start from "0" in v0.8.0.', DeprecationWarning)
310324
model = self.get_model()
@@ -375,6 +389,7 @@ def train(self):
375389
if self.max_steps and self.max_steps == self.global_step:
376390
self.main_progress_bar.close()
377391
model.on_train_end()
392+
self.on_train_end()
378393
return
379394

380395
# early stopping
@@ -390,17 +405,23 @@ def train(self):
390405
self.main_progress_bar.close()
391406
with self.profiler.profile('on_train_end'):
392407
model.on_train_end()
408+
self.on_train_end()
393409
return
394410

395411
self.main_progress_bar.close()
396412

397413
with self.profiler.profile('on_train_end'):
398414
model.on_train_end()
415+
self.on_train_end()
399416

400417
if self.logger is not None:
401418
self.logger.finalize("success")
402419

403420
def run_training_epoch(self):
421+
422+
# Epoch begin callbacks
423+
self.on_epoch_begin()
424+
404425
# before epoch hook
405426
if self.is_function_implemented('on_epoch_start'):
406427
model = self.get_model()
@@ -486,6 +507,9 @@ def run_training_epoch(self):
486507
with self.profiler.profile('on_epoch_end'):
487508
model.on_epoch_end()
488509

510+
# Epoch end callbacks
511+
self.on_epoch_end()
512+
489513
def run_training_batch(self, batch, batch_idx):
490514
# track grad norms
491515
grad_norm_dic = {}
@@ -499,6 +523,9 @@ def run_training_batch(self, batch, batch_idx):
499523
if batch is None:
500524
return 0, grad_norm_dic, {}
501525

526+
# Batch begin callbacks
527+
self.on_batch_begin()
528+
502529
# hook
503530
if self.is_function_implemented('on_batch_start'):
504531
model_ref = self.get_model()
@@ -610,6 +637,9 @@ def optimizer_closure():
610637
with self.profiler.profile('on_batch_end'):
611638
model.on_batch_end()
612639

640+
# Batch end callbacks
641+
self.on_batch_end()
642+
613643
# update progress bar
614644
self.main_progress_bar.update(1)
615645
self.main_progress_bar.set_postfix(**self.training_tqdm_dict)

0 commit comments

Comments
 (0)