Skip to content

Commit 6b62186

Browse files
williamFalcontullie
authored andcommitted
clean docs (Lightning-AI#967)
* clean docs * clean docs * clean docs
1 parent 98da6fc commit 6b62186

File tree

6 files changed

+52
-13
lines changed

6 files changed

+52
-13
lines changed

docs/source/callbacks.rst

+36
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,47 @@
33

44
Callbacks
55
=========
6+
7+
Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
8+
logic that is NOT required for your LightningModule to run.
9+
10+
An overall Lightning system should have:
11+
12+
1. Trainer for all engineering
13+
2. LightningModule for all research code.
14+
3. Callbacks for non-essential code.
15+
16+
Example
17+
18+
.. code-block:: python
19+
20+
import pytorch_lightning as pl
21+
22+
class MyPrintingCallback(pl.Callback):
23+
24+
def on_init_start(self, trainer):
25+
print('Starting to init trainer!')
26+
27+
def on_init_end(self, trainer):
28+
print('trainer is init now')
29+
30+
def on_train_end(self, trainer, pl_module):
31+
print('do something when training ends')
32+
33+
# pass to trainer
34+
trainer = pl.Trainer(callbacks=[MyPrintingCallback()])
35+
36+
We successfully extended functionality without polluting our super clean LightningModule research code
37+
38+
Callback Class
39+
--------------
40+
641
.. automodule:: pytorch_lightning.callbacks
742
:noindex:
843
:exclude-members:
944
_del_model,
1045
_save_model,
46+
_abc_impl,
1147
on_epoch_end,
1248
on_train_end,
1349
on_epoch_start,

docs/source/hooks.rst

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
Hooks
2+
-----
3+
14
.. automodule:: pytorch_lightning.core.hooks
25

36
Full list of hooks
4-
------------------
7+
58

69
Training set-up
7-
===============
10+
================
811
- init_ddp_connection
912
- init_optimizers
1013
- configure_apex

pytorch_lightning/callbacks/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
class Callback(abc.ABC):
1212
"""Abstract base class used to build new callbacks."""
1313

14-
def on_init_start(self, trainer, pl_module):
14+
def on_init_start(self, trainer):
1515
"""Called when the trainer initialization begins."""
16-
assert pl_module is None
16+
pass
1717

18-
def on_init_end(self, trainer, pl_module):
18+
def on_init_end(self, trainer):
1919
"""Called when the trainer initialization ends."""
2020
pass
2121

pytorch_lightning/trainer/callback_hook.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ def __init__(self):
1212
self.callbacks: list[Callback] = []
1313
self.get_model: Callable = ...
1414

15-
def on_init_start(self):
15+
def on_init_start(self, trainer):
1616
"""Called when the trainer initialization begins."""
1717
for callback in self.callbacks:
18-
callback.on_init_start(self, None)
18+
callback.on_init_start(trainer)
1919

20-
def on_init_end(self):
20+
def on_init_end(self, trainer):
2121
"""Called when the trainer initialization ends."""
2222
for callback in self.callbacks:
23-
callback.on_init_end(self, self.get_model())
23+
callback.on_init_end(trainer)
2424

2525
def on_fit_start(self):
2626
"""Called when the fit begins."""

pytorch_lightning/trainer/trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def on_train_end(self):
618618

619619
# Init callbacks
620620
self.callbacks = callbacks
621-
self.on_init_start()
621+
self.on_init_start(self)
622622

623623
# benchmarking
624624
self.benchmark = benchmark
@@ -808,7 +808,7 @@ def on_train_end(self):
808808
self.init_amp(use_amp)
809809

810810
# Callback system
811-
self.on_init_end()
811+
self.on_init_end(self)
812812

813813
@property
814814
def slurm_job_id(self) -> int:

tests/trainer/test_trainer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,10 @@ def __init__(self):
630630
self.on_test_start_called = False
631631
self.on_test_end_called = False
632632

633-
def on_init_start(self, trainer, pl_module):
633+
def on_init_start(self, trainer):
634634
self.on_init_start_called = True
635635

636-
def on_init_end(self, trainer, pl_module):
636+
def on_init_end(self, trainer):
637637
self.on_init_end_called = True
638638

639639
def on_fit_start(self, trainer, pl_module):

0 commit comments

Comments
 (0)