
Commit 705e576

consolidate callbacks and hooks (#950)
* consolidate callbacks and hooks
* ensure callbacks receive proper arg types
* remove model from init callback events
* clean up early stopping event
* update changelog
* remove on_fit_start and on_fit_end
* fix args for on_init_start and on_init_end
* handle case where early stopping is not used
* show all callback methods
* wrap checkpoint callback logic into proper class
* fix check for main process in checkpoint callback
* move callbacks test to separate file
* refactor arg checks
* get model and call hook on same line
* define trainer_options dict in one call
* add more asserts to callback test
1 parent 1789165 commit 705e576

11 files changed, +220 -214 lines

CHANGELOG.md (+1)
@@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
 - Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
 - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
+- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))

 ### Changed
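
The "user defined callbacks" entry refers to subclassing the `Callback` base class and handing instances to the `Trainer`. A minimal usage sketch, assuming only what the hunks below show (`Callback` in pytorch_lightning/callbacks/base.py and the `callbacks` argument stored by `Trainer.__init__`); the class name and print statements are illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks.base import Callback

    class PrintingCallback(Callback):  # illustrative name, not part of this commit
        def on_init_start(self, trainer):
            # fired from Trainer.__init__; the model is not attached yet
            print('trainer init started')

        def on_train_start(self, trainer, pl_module):
            # fired from the training loop once the LightningModule is available
            print('training starts')

    trainer = Trainer(callbacks=[PrintingCallback()])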

docs/source/callbacks.rst (-4)
@@ -44,8 +44,4 @@ Callback Class
 _del_model,
 _save_model,
 _abc_impl,
-on_epoch_end,
-on_train_end,
-on_epoch_start,
 check_monitor_top_k,
-on_train_start,

pytorch_lightning/callbacks/base.py (+2 -10)
@@ -12,19 +12,11 @@ class Callback(abc.ABC):
     """Abstract base class used to build new callbacks."""

     def on_init_start(self, trainer):
-        """Called when the trainer initialization begins."""
+        """Called when the trainer initialization begins, model has not yet been set."""
         pass

     def on_init_end(self, trainer):
-        """Called when the trainer initialization ends."""
-        pass
-
-    def on_fit_start(self, trainer, pl_module):
-        """Called when the fit begins."""
-        pass
-
-    def on_fit_end(self, trainer, pl_module):
-        """Called when the fit ends."""
+        """Called when the trainer initialization ends, model has not yet been set."""
         pass

     def on_epoch_start(self, trainer, pl_module):
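
Note the signature split the new docstrings describe: the two init hooks receive only the trainer (the model has not been set yet), while every later hook also receives the LightningModule. A hedged sketch of a subclass relying on that distinction; the attribute accesses appear in the other hunks of this commit, the messages are illustrative:

    from pytorch_lightning.callbacks.base import Callback

    class SetupLogger(Callback):  # illustrative
        def on_init_end(self, trainer):
            # only the trainer is available here
            print(f'{len(trainer.callbacks)} callbacks registered')

        def on_epoch_start(self, trainer, pl_module):
            # the module is attached by now and passed explicitly
            print(f'epoch {trainer.current_epoch} on {pl_module.__class__.__name__}')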

pytorch_lightning/callbacks/early_stopping.py (-2)
@@ -64,8 +64,6 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.monitor_op = mode_dict[mode]
         self.min_delta *= 1 if self.monitor_op == np.greater else -1

-        self.on_train_start(None, None)
-
     def check_metrics(self, logs):
         monitor_val = logs.get(self.monitor)
         error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'

pytorch_lightning/callbacks/model_checkpoint.py (+4)
@@ -118,6 +118,10 @@ def check_monitor_top_k(self, current):
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])

     def on_validation_end(self, trainer, pl_module):
+        # only run on main process
+        if trainer.proc_rank != 0:
+            return
+
         logs = trainer.callback_metrics
         epoch = trainer.current_epoch
         self.epochs_since_last_check += 1
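
The rank-zero guard that previously lived in the evaluation loop (see the evaluation_loop.py hunk further down) now sits inside the callback itself, so the checkpoint logic stays self-contained. A hedged sketch of the same guard applied to a user callback with side effects; the class name is hypothetical, while `trainer.proc_rank` and `trainer.callback_metrics` are the attributes used above:

    from pytorch_lightning.callbacks.base import Callback

    class ExportMetrics(Callback):  # hypothetical example, not part of this commit
        def on_validation_end(self, trainer, pl_module):
            # every distributed process runs the callback, so guard side effects
            if trainer.proc_rank != 0:
                return
            metrics = trainer.callback_metrics
            print(f"epoch {trainer.current_epoch}: val_loss={metrics.get('val_loss')}")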

pytorch_lightning/trainer/callback_hook.py (+6 -16)
@@ -12,25 +12,15 @@ def __init__(self):
         self.callbacks: list[Callback] = []
         self.get_model: Callable = ...

-    def on_init_start(self, trainer):
-        """Called when the trainer initialization begins."""
+    def on_init_start(self):
+        """Called when the trainer initialization begins, model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_init_start(trainer)
+            callback.on_init_start(self)

-    def on_init_end(self, trainer):
-        """Called when the trainer initialization ends."""
+    def on_init_end(self):
+        """Called when the trainer initialization ends, model has not yet been set."""
         for callback in self.callbacks:
-            callback.on_init_end(trainer)
-
-    def on_fit_start(self):
-        """Called when the fit begins."""
-        for callback in self.callbacks:
-            callback.on_fit_start(self, self.get_model())
-
-    def on_fit_end(self):
-        """Called when the fit ends."""
-        for callback in self.callbacks:
-            callback.on_fit_end(self, self.get_model())
+            callback.on_init_end(self)

     def on_epoch_start(self):
         """Called when the epoch begins."""

pytorch_lightning/trainer/evaluation_loop.py (+3 -4)
@@ -374,14 +374,13 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.val_progress_bar.close()

-        # model checkpointing
-        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
-            self.checkpoint_callback.on_validation_end(self, self.get_model())
-
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
         else:
+            # model checkpointing
+            if self.checkpoint_callback is not None:
+                self.checkpoint_callback.on_validation_end(self, self.get_model())
             self.on_validation_end()

     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):

pytorch_lightning/trainer/trainer.py (+2 -8)
@@ -618,7 +618,7 @@ def on_train_end(self):

         # Init callbacks
         self.callbacks = callbacks
-        self.on_init_start(self)
+        self.on_init_start()

         # benchmarking
         self.benchmark = benchmark
@@ -808,7 +808,7 @@ def on_train_end(self):
         self.init_amp(use_amp)

         # Callback system
-        self.on_init_end(self)
+        self.on_init_end()

     @property
     def slurm_job_id(self) -> int:
@@ -941,9 +941,6 @@ def fit(
         # bind logger
         model.logger = self.logger

-        # Fit begin callbacks
-        self.on_fit_start()
-
         # set up the passed in dataloaders (if needed)
         self.__set_fit_dataloaders(model, train_dataloader, val_dataloaders, test_dataloaders)

@@ -1006,9 +1003,6 @@ def fit(

         self.run_pretrain_routine(model)

-        # Fit end callbacks
-        self.on_fit_end()
-
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return 1

pytorch_lightning/trainer/training_loop.py (+49 -51)
@@ -302,9 +302,15 @@ def train(self):
         self.reset_train_dataloader(model)
         self.reset_val_dataloader(model)

-        # Train begin callbacks
-        model.on_train_start()
-        self.on_train_start()
+        # Train start events
+        with self.profiler.profile('on_train_start'):
+            # callbacks
+            self.on_train_start()
+            # initialize early stop callback
+            if self.early_stop_callback is not None:
+                self.early_stop_callback.on_train_start(self, self.get_model())
+            # model hooks
+            model.on_train_start()

         try:
             # run all epochs
@@ -347,9 +353,6 @@ def train(self):
                 desc = f'Epoch {epoch + 1}' if not self.is_infinite_dataloader(self.train_dataloader) else ''
                 self.main_progress_bar.set_description(desc)

-                # changing gradient according accumulation_scheduler
-                self.accumulation_scheduler.on_epoch_start(self, self.get_model())
-
                 # -----------------
                 # RUN TNG EPOCH
                 # -----------------
@@ -369,23 +372,21 @@ def train(self):
                     self.reduce_lr_on_plateau_scheduler.step(val_loss)

                 if self.max_steps and self.max_steps == self.global_step:
-                    self.main_progress_bar.close()
-                    model.on_train_end()
-                    self.on_train_end()
+                    self.run_training_teardown()
                     return

                 # early stopping
                 met_min_epochs = epoch >= self.min_epochs - 1
                 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

+                # TODO wrap this logic into the callback
                 if self.enable_early_stop and not self.disable_validation and is_val_epoch:
                     if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
                         should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
                         # stop training
                         stop = should_stop and met_min_epochs
                         if stop:
                             self.run_training_teardown()
-                            self.on_train_end()
                             return

             self.run_training_teardown()
@@ -394,19 +395,17 @@ def train(self):
             log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
             self.run_training_teardown()

-        # Train end callbacks
-        self.on_train_end()
-
     def run_training_epoch(self):

-        # Epoch begin callbacks
-        self.on_epoch_start()
-
-        # before epoch hook
-        if self.is_function_implemented('on_epoch_start'):
-            model = self.get_model()
-            with self.profiler.profile('on_epoch_start'):
-                model.on_epoch_start()
+        # Epoch start events
+        with self.profiler.profile('on_epoch_start'):
+            # callbacks
+            self.on_epoch_start()
+            # changing gradient according accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_start(self, self.get_model())
+            # model hooks
+            if self.is_function_implemented('on_epoch_start'):
+                self.get_model().on_epoch_start()

         # reset train dataloader
         if self.reload_dataloaders_every_epoch:
@@ -485,14 +484,13 @@ def run_training_epoch(self):
             if early_stop_epoch or self.fast_dev_run:
                 break

-        # epoch end hook
-        if self.is_function_implemented('on_epoch_end'):
-            model = self.get_model()
-            with self.profiler.profile('on_epoch_end'):
-                model.on_epoch_end()
-
-        # Epoch begin callbacks
-        self.on_epoch_end()
+        # Epoch end events
+        with self.profiler.profile('on_epoch_end'):
+            # callbacks
+            self.on_epoch_end()
+            # model hooks
+            if self.is_function_implemented('on_epoch_end'):
+                self.get_model().on_epoch_end()

     def run_training_batch(self, batch, batch_idx):
         # track grad norms
@@ -507,17 +505,15 @@ def run_training_batch(self, batch, batch_idx):
         if batch is None:
             return 0, grad_norm_dic, {}

-        # Batch begin callbacks
-        self.on_batch_start()
-
-        # hook
-        if self.is_function_implemented('on_batch_start'):
-            model_ref = self.get_model()
-            with self.profiler.profile('on_batch_start'):
-                response = model_ref.on_batch_start(batch)
-
-            if response == -1:
-                return -1, grad_norm_dic, {}
+        # Batch start events
+        with self.profiler.profile('on_batch_start'):
+            # callbacks
+            self.on_batch_start()
+            # hooks
+            if self.is_function_implemented('on_batch_start'):
+                response = self.get_model().on_batch_start(batch)
+                if response == -1:
+                    return -1, grad_norm_dic, {}

         splits = [batch]
         if self.truncated_bptt_steps is not None:
@@ -612,14 +608,13 @@ def optimizer_closure():
                     self.batch_loss_value = 0
                     self.avg_loss = np.mean(self.running_loss[-100:])

-        # activate batch end hook
-        if self.is_function_implemented('on_batch_end'):
-            model = self.get_model()
-            with self.profiler.profile('on_batch_end'):
-                model.on_batch_end()
-
-        # Batch end callbacks
-        self.on_batch_end()
+        # Batch end events
+        with self.profiler.profile('on_batch_end'):
+            # callbacks
+            self.on_batch_end()
+            # model hooks
+            if self.is_function_implemented('on_batch_end'):
+                self.get_model().on_batch_end()

         # update progress bar
         if batch_idx % self.progress_bar_refresh_rate == 0:
@@ -635,12 +630,15 @@ def optimizer_closure():
         return 0, grad_norm_dic, all_log_metrics

     def run_training_teardown(self):
-        model = self.get_model()
-
         self.main_progress_bar.close()

+        # Train end events
         with self.profiler.profile('on_train_end'):
-            model.on_train_end()
+            # callbacks
+            self.on_train_end()
+            # model hooks
+            if self.is_function_implemented('on_train_end'):
+                self.get_model().on_train_end()

         if self.logger is not None:
             self.logger.finalize("success")
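
After this change every lifecycle event follows the same shape: one profiler block per event, trainer callbacks dispatched first, the optional model hook second, and `run_training_teardown` is the single place train-end fires (normal exit, `max_steps`, early stopping, or `KeyboardInterrupt`). The commit message mentions extra asserts in the callbacks test; a hedged sketch of the kind of event-recording callback such a test can use to check that ordering, with an illustrative name:

    from pytorch_lightning.callbacks.base import Callback

    class EventRecorder(Callback):  # illustrative test helper
        def __init__(self):
            self.events = []

        def on_train_start(self, trainer, pl_module):
            self.events.append('on_train_start')

        def on_epoch_start(self, trainer, pl_module):
            self.events.append('on_epoch_start')

        def on_batch_end(self, trainer, pl_module):
            self.events.append('on_batch_end')

        def on_train_end(self, trainer, pl_module):
            self.events.append('on_train_end')

    # after trainer.fit(model), recorder.events should start with 'on_train_start'
    # and end with 'on_train_end' regardless of how training stopped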
