|
25 | 25 | from pytorch_lightning.trainer.training_io import TrainerIOMixin
|
26 | 26 | from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
|
27 | 27 | from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
| 28 | +from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin |
28 | 29 | from pytorch_lightning.utilities.debugging import MisconfigurationException
|
29 | 30 | from pytorch_lightning.profiler import Profiler, PassThroughProfiler
|
30 | 31 |
|
@@ -57,13 +58,15 @@ class Trainer(TrainerIOMixin,
|
57 | 58 | TrainerEvaluationLoopMixin,
|
58 | 59 | TrainerTrainLoopMixin,
|
59 | 60 | TrainerCallbackConfigMixin,
|
| 61 | + TrainerCallbackHookMixin |
60 | 62 | ):
|
61 | 63 |
|
62 | 64 | def __init__(
|
63 | 65 | self,
|
64 | 66 | logger=True,
|
65 | 67 | checkpoint_callback=True,
|
66 | 68 | early_stop_callback=None,
|
| 69 | + callbacks: list = [], |
67 | 70 | default_save_path=None,
|
68 | 71 | gradient_clip_val=0,
|
69 | 72 | gradient_clip=None, # backward compatible, todo: remove in v0.8.0
|
@@ -163,6 +166,22 @@ def __init__(
|
163 | 166 |
|
164 | 167 | trainer = Trainer(early_stop_callback=early_stop_callback)
|
165 | 168 |
|
| 169 | +            callbacks (list of :class:`.Callback`): Add a list of callbacks. |
| 170 | + Example:: |
| 171 | +
|
| 172 | + from pytorch_lightning.callbacks import Callback |
| 173 | +
|
| 174 | + class PrintCallback(Callback): |
| 175 | + def on_train_begin(self): |
| 176 | + print("Training is started!") |
| 177 | +
|
| 178 | + def on_train_end(self): |
| 179 | + print(f"Training is done. The logs are: {self.trainer.logs}") |
| 180 | +
|
| 181 | + # a list of callbacks |
| 182 | + callbacks = [PrintCallback()] |
| 183 | + trainer = Trainer(callbacks=callbacks) |
| 184 | +
|
166 | 185 | default_save_path (str): Default path for logs and weights when no logger/ckpt_callback passed
|
167 | 186 | Example::
|
168 | 187 |
|
@@ -579,6 +598,10 @@ def __init__(
|
579 | 598 |
|
580 | 599 | """
|
581 | 600 |
|
| 601 | + # Init callbacks |
| 602 | + self.callbacks = callbacks |
| 603 | + self.on_init_begin() |
| 604 | + |
582 | 605 | # Transfer params
|
583 | 606 | # Backward compatibility
|
584 | 607 | if nb_gpu_nodes is not None:
|
@@ -761,6 +784,8 @@ def __init__(
|
761 | 784 | use_amp = True
|
762 | 785 | self.init_amp(use_amp)
|
763 | 786 |
|
| 787 | + self.on_init_end() |
| 788 | + |
764 | 789 | @property
|
765 | 790 | def slurm_job_id(self):
|
766 | 791 | try:
|
@@ -890,6 +915,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
|
890 | 915 | _set_dataloader(model, val_dataloader, 'val_dataloader')
|
891 | 916 | _set_dataloader(model, test_dataloader, 'test_dataloader')
|
892 | 917 |
|
| 918 | + # Training begin callbacks |
| 919 | + self.on_fit_begin() |
| 920 | + |
893 | 921 | # when using multi-node or DDP within a node start each module in a separate process
|
894 | 922 | if self.use_ddp2:
|
895 | 923 | task = int(os.environ['SLURM_LOCALID'])
|
@@ -929,6 +957,9 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
|
929 | 957 |
|
930 | 958 | self.run_pretrain_routine(model)
|
931 | 959 |
|
| 960 | + # Training end callbacks |
| 961 | + self.on_fit_end() |
| 962 | + |
932 | 963 | # return 1 when finished
|
933 | 964 | # used for testing or when we need to know that training succeeded
|
934 | 965 | return 1
|
@@ -1082,6 +1113,7 @@ def test(self, model=None):
|
1082 | 1113 | trainer = Trainer()
|
1083 | 1114 | trainer.test(model)
|
1084 | 1115 | """
|
| 1116 | + |
1085 | 1117 | self.testing = True
|
1086 | 1118 | if model is not None:
|
1087 | 1119 | self.fit(model)
|
|
0 commit comments