Skip to content

Commit bf6b963

Browse files
committed
move callback system to TrainerCallbackHookMixin
1 parent b9a04f5 commit bf6b963

File tree

6 files changed

+119
-53
lines changed

6 files changed

+119
-53
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

-16
Original file line numberDiff line numberDiff line change
@@ -166,29 +166,17 @@ def on_validation_end(self):
166166

167167
def _do_check_save(self, filepath, current, epoch):
168168
# remove kth
169-
<<<<<<< HEAD
170169
if len(self.best_k_models) == self.save_top_k:
171-
=======
172-
if len(self.best_k_models.keys()) == self.save_top_k:
173-
>>>>>>> fix logic error
174170
delpath = self.kth_best_model
175171
self.best_k_models.pop(self.kth_best_model)
176172
self._del_model(delpath)
177173

178174
self.best_k_models[filepath] = current
179-
<<<<<<< HEAD
180175
if len(self.best_k_models) == self.save_top_k:
181176
# monitor dict has reached k elements
182177
_op = max if self.mode == 'min' else min
183178
self.kth_best_model = _op(self.best_k_models,
184179
key=self.best_k_models.get)
185-
=======
186-
if len(self.best_k_models.keys()) == self.save_top_k:
187-
# monitor dict has reached k elements
188-
_op = max if self.mode == 'min' else min
189-
self.kth_best_model = _op(self.best_k_models,
190-
key=self.best_k_models.get)
191-
>>>>>>> fix logic error
192180
self.kth_value = self.best_k_models[self.kth_best_model]
193181

194182
_op = min if self.mode == 'min' else max
@@ -199,8 +187,4 @@ def _do_check_save(self, filepath, current, epoch):
199187
f'\nEpoch {epoch:05d}: {self.monitor} reached'
200188
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
201189
f' {filepath} as top {self.save_top_k}')
202-
<<<<<<< HEAD
203-
self._save_model(filepath)
204-
=======
205190
self._save_model(filepath)
206-
>>>>>>> fix logic error
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
from abc import ABC
3+
4+
from pytorch_lightning.callbacks import Callback
5+
6+
7+
class TrainerCallbackHookMixin(ABC):
    """Mixin that fans trainer lifecycle events out to user ``Callback`` objects.

    The ``Trainer`` inherits this mixin; each ``on_*`` hook below simply
    forwards the event, in registration order, to every callback stored in
    ``self.callbacks``.
    """

    def __init__(self):
        # This is just a summary of the variables used in this abstract class;
        # the proper values/initialisation are done in the child class.
        #
        # NOTE: annotated via comment rather than ``self.callbacks: list[Callback]``
        # because an annotation on a complex target (``self.x``) is evaluated at
        # runtime (PEP 526), and subscripting the ``list`` builtin raises
        # TypeError on Python < 3.9.
        self.callbacks = []  # list of Callback instances

    def on_init_begin(self):
        """Called when the trainer initialization begins.

        Also binds the trainer to each callback via ``set_trainer`` so that
        later hooks can access trainer state.
        """
        for callback in self.callbacks:
            callback.set_trainer(self)
            callback.on_init_begin()

    def on_init_end(self):
        """Called when the trainer initialization ends."""
        for callback in self.callbacks:
            callback.on_init_end()

    def on_fit_begin(self):
        """Called when the fit begins."""
        for callback in self.callbacks:
            callback.on_fit_begin()

    def on_fit_end(self):
        """Called when the fit ends."""
        for callback in self.callbacks:
            callback.on_fit_end()

    def on_epoch_begin(self):
        """Called when the epoch begins."""
        for callback in self.callbacks:
            callback.on_epoch_begin()

    def on_epoch_end(self):
        """Called when the epoch ends."""
        for callback in self.callbacks:
            callback.on_epoch_end()

    def on_train_begin(self):
        """Called when the train begins."""
        for callback in self.callbacks:
            callback.on_train_begin()

    def on_train_end(self):
        """Called when the train ends."""
        for callback in self.callbacks:
            callback.on_train_end()

    def on_batch_begin(self):
        """Called when the training batch begins."""
        for callback in self.callbacks:
            callback.on_batch_begin()

    def on_batch_end(self):
        """Called when the training batch ends."""
        for callback in self.callbacks:
            callback.on_batch_end()

    def on_validation_begin(self):
        """Called when the validation loop begins."""
        for callback in self.callbacks:
            callback.on_validation_begin()

    def on_validation_end(self):
        """Called when the validation loop ends."""
        for callback in self.callbacks:
            callback.on_validation_end()

    def on_test_begin(self):
        """Called when the test begins."""
        for callback in self.callbacks:
            callback.on_test_begin()

    def on_test_end(self):
        """Called when the test ends."""
        for callback in self.callbacks:
            callback.on_test_end()

pytorch_lightning/trainer/evaluation_loop.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,12 @@ def __init__(self):
168168
self.get_test_dataloaders = None
169169
self.get_val_dataloaders = None
170170
self.use_tpu = None
171-
self.callbacks = []
171+
172+
# Callback system
173+
self.on_validation_begin = None
174+
self.on_validation_end = None
175+
self.on_test_begin = None
176+
self.on_test_end = None
172177

173178
@abstractmethod
174179
def copy_trainer_model_properties(self, model):
@@ -295,11 +300,10 @@ def run_evaluation(self, test=False):
295300
raise MisconfigurationException(m)
296301

297302
# Validation/Test begin callbacks
298-
for callback in self.callbacks:
299-
if test:
300-
callback.on_test_begin()
301-
else:
302-
callback.on_validation_begin()
303+
if test:
304+
self.on_test_begin()
305+
else:
306+
self.on_validation_begin()
303307

304308
# hook
305309
model = self.get_model()
@@ -362,11 +366,10 @@ def run_evaluation(self, test=False):
362366
self.checkpoint_callback.on_validation_end()
363367

364368
# Validation/Test end callbacks
365-
for callback in self.callbacks:
366-
if test:
367-
callback.on_test_end()
368-
else:
369-
callback.on_validation_end()
369+
if test:
370+
self.on_test_end()
371+
else:
372+
self.on_validation_end()
370373

371374
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
372375
# make dataloader_idx arg in validation_step optional

pytorch_lightning/trainer/trainer.py

+7-11
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pytorch_lightning.trainer.training_io import TrainerIOMixin
2626
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
2727
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
28+
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
2829
from pytorch_lightning.utilities.debugging import MisconfigurationException
2930
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
3031

@@ -57,6 +58,7 @@ class Trainer(TrainerIOMixin,
5758
TrainerEvaluationLoopMixin,
5859
TrainerTrainLoopMixin,
5960
TrainerCallbackConfigMixin,
61+
TrainerCallbackHookMixin
6062
):
6163

6264
def __init__(
@@ -596,11 +598,9 @@ def on_train_end(self):
596598
597599
"""
598600

599-
# Init start callbacks
601+
# Init callbacks
600602
self.callbacks = callbacks
601-
for callback in self.callbacks:
602-
callback.set_trainer(self)
603-
callback.on_init_begin()
603+
self.on_init_begin()
604604

605605
# Transfer params
606606
# Backward compatibility
@@ -784,9 +784,7 @@ def on_train_end(self):
784784
use_amp = True
785785
self.init_amp(use_amp)
786786

787-
# Init end callbacks
788-
for callback in self.callbacks:
789-
callback.on_init_end()
787+
self.on_init_end()
790788

791789
@property
792790
def slurm_job_id(self):
@@ -918,8 +916,7 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
918916
_set_dataloader(model, test_dataloader, 'test_dataloader')
919917

920918
# Training begin callbacks
921-
for callback in self.callbacks:
922-
callback.on_fit_begin()
919+
self.on_fit_begin()
923920

924921
# when using multi-node or DDP within a node start each module in a separate process
925922
if self.use_ddp2:
@@ -961,8 +958,7 @@ def fit(self, model, train_dataloader=None, val_dataloader=None, test_dataloader
961958
self.run_pretrain_routine(model)
962959

963960
# Training end callbacks
964-
for callback in self.callbacks:
965-
callback.on_fit_end()
961+
self.on_fit_end()
966962

967963
# return 1 when finished
968964
# used for testing or when we need to know that training succeeded

pytorch_lightning/trainer/training_loop.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,15 @@ def __init__(self):
231231
self.batch_idx = None
232232
self.precision = None
233233
self.callbacks = []
234+
self.max_steps = None
235+
236+
# Callback system
237+
self.on_train_begin = None
238+
self.on_train_end = None
239+
self.on_batch_begin = None
240+
self.on_batch_end = None
241+
self.on_epoch_begin = None
242+
self.on_epoch_end = None
234243

235244
@property
236245
def max_nb_epochs(self):
@@ -308,8 +317,7 @@ def process_output(self, output, train):
308317
def train(self):
309318

310319
# Train begin callbacks
311-
for callback in self.callbacks:
312-
callback.on_train_begin()
320+
self.on_train_begin()
313321

314322
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
315323
' but will start from "0" in v0.8.0.', DeprecationWarning)
@@ -404,17 +412,15 @@ def train(self):
404412
model.on_train_end()
405413

406414
# Train end callbacks
407-
for callback in self.callbacks:
408-
callback.on_train_end()
415+
self.on_train_end()
409416

410417
if self.logger is not None:
411418
self.logger.finalize("success")
412419

413420
def run_training_epoch(self):
414421

415422
# Epoch begin callbacks
416-
for callback in self.callbacks:
417-
callback.on_epoch_begin()
423+
self.on_epoch_begin()
418424

419425
# before epoch hook
420426
if self.is_function_implemented('on_epoch_start'):
@@ -502,8 +508,7 @@ def run_training_epoch(self):
502508
model.on_epoch_end()
503509

504510
# Epoch begin callbacks
505-
for callback in self.callbacks:
506-
callback.on_epoch_end()
511+
self.on_epoch_end()
507512

508513
def run_training_batch(self, batch, batch_idx):
509514
# track grad norms
@@ -519,8 +524,7 @@ def run_training_batch(self, batch, batch_idx):
519524
return 0, grad_norm_dic, {}
520525

521526
# Batch begin callbacks
522-
for callback in self.callbacks:
523-
callback.on_batch_begin()
527+
self.on_batch_begin()
524528

525529
# hook
526530
if self.is_function_implemented('on_batch_start'):
@@ -634,8 +638,7 @@ def optimizer_closure():
634638
model.on_batch_end()
635639

636640
# Batch end callbacks
637-
for callback in self.callbacks:
638-
callback.on_batch_end()
641+
self.on_batch_end()
639642

640643
# update progress bar
641644
self.main_progress_bar.update(1)

tests/test_trainer.py

-3
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,6 @@ class CurrentTestModel(
453453
trainer.test()
454454

455455

456-
<<<<<<< HEAD
457456
def test_train_dataloaders_passed_to_fit(tmpdir):
458457
""" Verify that train dataloader can be passed to fit """
459458
tutils.reset_seed()
@@ -697,7 +696,6 @@ def test_trainer_min_steps_and_epochs(tmpdir):
697696
assert trainer.global_step >= math.floor(num_train_samples * 1.5) and \
698697
trainer.current_epoch > 0, "Model did not train for at least min_steps"
699698

700-
=======
701699
def test_callbacks():
702700
"""Test callbacks mechanics."""
703701
tutils.reset_seed()
@@ -808,7 +806,6 @@ def on_test_end(self):
808806

809807
assert test_callback.test_begin_called
810808
assert test_callback.test_end_called
811-
>>>>>>> add callbacks arguments to the trainer + associated tests
812809

813810
# if __name__ == '__main__':
814811
# pytest.main([__file__])

0 commit comments

Comments
 (0)