|
@@ -30,8 +30,10 @@
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
+from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.profiler import Profiler, PassThroughProfiler
+from pytorch_lightning.callbacks import Callback


 try:
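The `Callback` base class imported above is not part of this excerpt. Judging from the hooks this diff wires up (`on_init_begin`/`on_init_end`, `on_fit_begin`/`on_fit_end`) and the docstring example further down, it is presumably a set of no-op methods that subclasses override. A minimal sketch, not the actual implementation:

```python
class Callback:
    """Abstract base for trainer callbacks (sketch; hook names are taken
    from the calls added in this diff)."""

    def on_init_begin(self):
        """Called at the start of Trainer.__init__."""

    def on_init_end(self):
        """Called at the end of Trainer.__init__."""

    def on_fit_begin(self):
        """Called at the start of Trainer.fit()."""

    def on_fit_end(self):
        """Called at the end of Trainer.fit()."""

    def on_train_begin(self):
        """Called when the training loop starts."""

    def on_train_end(self):
        """Called when the training loop ends."""
```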
@@ -62,13 +64,15 @@ class Trainer(TrainerIOMixin,
               TrainerEvaluationLoopMixin,
               TrainerTrainLoopMixin,
               TrainerCallbackConfigMixin,
+              TrainerCallbackHookMixin
               ):

     def __init__(
             self,
             logger: Union[LightningLoggerBase, bool] = True,
             checkpoint_callback: Union[ModelCheckpoint, bool] = True,
             early_stop_callback: Optional[Union[EarlyStopping, bool]] = None,
+            callbacks: Optional[list] = None,
             default_save_path: Optional[str] = None,
             gradient_clip_val: float = 0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
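`TrainerCallbackHookMixin` itself is also outside this excerpt. Since `__init__` assigns `self.callbacks` before invoking `self.on_init_begin()` (see the hunk at line 600 below), the mixin presumably just fans each hook out to every registered callback, in list order. A sketch under that assumption:

```python
class TrainerCallbackHookMixin:
    """Forwards trainer lifecycle hooks to all registered callbacks (sketch)."""

    # Assigned in Trainer.__init__ before the first hook fires.
    callbacks: list = []

    def on_init_begin(self):
        for callback in self.callbacks:
            callback.on_init_begin()

    def on_init_end(self):
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        for callback in self.callbacks:
            callback.on_fit_end()
```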
@@ -168,6 +172,18 @@ def __init__(

                trainer = Trainer(early_stop_callback=early_stop_callback)

+        callbacks (list of :class:`.Callback`): Add a list of callbacks.
+            Example::
+                from pytorch_lightning.callbacks import Callback
+                class PrintCallback(Callback):
+                    def on_train_begin(self):
+                        print("Training has started!")
+                    def on_train_end(self):
+                        print("Training is done.")
+                # a list of callbacks
+                callbacks = [PrintCallback()]
+                trainer = Trainer(callbacks=callbacks)
+
         default_save_path: Default path for logs and weights when no logger/ckpt_callback passed
             Example::

@@ -584,6 +600,10 @@ def __init__(

        """

+        # Init callbacks
+        self.callbacks = callbacks or []
+        self.on_init_begin()
+
        # Transfer params
        # Backward compatibility
        if nb_gpu_nodes is not None:
@@ -766,6 +786,9 @@ def __init__(
            use_amp = True
        self.init_amp(use_amp)

+        # Callback system
+        self.on_init_end()
+
    @property
    def slurm_job_id(self) -> int:
        try:
@@ -901,6 +924,9 @@ def fit(
        _set_dataloader(model, val_dataloader, 'val_dataloader')
        _set_dataloader(model, test_dataloader, 'test_dataloader')

+        # Fit begin callbacks
+        self.on_fit_begin()
+
        # when using multi-node or DDP within a node start each module in a separate process
        if self.use_ddp2:
            task = int(os.environ['SLURM_LOCALID'])
@@ -940,6 +966,9 @@ def fit(

        self.run_pretrain_routine(model)

+        # Fit end callbacks
+        self.on_fit_end()
+
        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        return 1
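Putting the four insertion points together: the two init hooks bracket `__init__` and the two fit hooks bracket `fit()`. A small callback that traces this ordering (hook names as used in this diff; `trainer.fit(model)` would still need a `LightningModule` to run):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class LifecycleLogger(Callback):
    """Prints each trainer lifecycle hook as it fires."""

    def on_init_begin(self):
        print("Trainer.__init__ entered")

    def on_init_end(self):
        print("Trainer.__init__ finished")

    def on_fit_begin(self):
        print("fit() starting")

    def on_fit_end(self):
        print("fit() finished")


# Constructing the trainer already fires both init hooks;
# trainer.fit(model) would then fire on_fit_begin()/on_fit_end().
trainer = Trainer(callbacks=[LifecycleLogger()])
```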
@@ -1034,8 +1063,8 @@ def run_pretrain_routine(self, model: LightningModule):
            return

        # check if we should run validation during training
-        self.disable_validation = ((self.num_val_batches == 0 or
-                                    not self.is_overriden('validation_step')) and
+        self.disable_validation = ((self.num_val_batches == 0
+                                    or not self.is_overriden('validation_step')) and
                                   not self.fast_dev_run)

        # run tiny validation (if validation defined)
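The hunk above only reflows the condition so that `or` leads the continuation line (PEP 8 now recommends breaking before binary operators); the logic is unchanged. Pulled out into a standalone predicate with illustrative names, it reads:

```python
def validation_disabled(num_val_batches: int,
                        has_validation_step: bool,
                        fast_dev_run: bool) -> bool:
    # Validation is skipped when there is nothing to validate (no val batches,
    # or no validation_step override), unless fast_dev_run forces a tiny run.
    nothing_to_validate = num_val_batches == 0 or not has_validation_step
    return nothing_to_validate and not fast_dev_run


assert validation_disabled(0, False, fast_dev_run=False)
assert not validation_disabled(0, False, fast_dev_run=True)
assert not validation_disabled(10, True, fast_dev_run=False)
```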
@@ -1139,7 +1168,7 @@ def _set_dataloader(model, dataloader, attribute):
    if is_dataloader or is_dataloader_list and valid_loaders:

        # Overwrite abstract methods
-        dl = lambda: dataloader
+        def dl(): return dataloader
        dl.__name__ = attribute
        setattr(model, attribute, dl)

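The last hunk swaps a lambda assignment for a one-line `def`, which silences flake8's E731 without changing behavior; the `__name__` assignment on the next line works either way, and it is what lets the wrapper masquerade as the expected loader method. A self-contained illustration (hypothetical helper, not Lightning code):

```python
def make_loader_hook(dataloader, attribute):
    # Same pattern as the patched code: wrap an existing dataloader in a
    # zero-argument function so it can stand in for e.g. train_dataloader().
    def dl():
        return dataloader
    dl.__name__ = attribute  # report the expected method name
    return dl


hook = make_loader_hook([1, 2, 3], 'train_dataloader')
print(hook.__name__)  # train_dataloader
print(hook())         # [1, 2, 3]
```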
|
|