File tree 6 files changed +52
-13
lines changed
6 files changed +52
-13
lines changed Original file line number Diff line number Diff line change 3
3
4
4
Callbacks
5
5
=========
6
+
7
+ Lightning has a callback system to execute arbitrary code. Callbacks should capture NON-ESSENTIAL
8
+ logic that is NOT required for your LightningModule to run.
9
+
10
+ An overall Lightning system should have:
11
+
12
+ 1. Trainer for all engineering
13
+ 2. LightningModule for all research code.
14
+ 3. Callbacks for non-essential code.
15
+
16
+ Example
17
+
18
+ .. code-block :: python
19
+
20
+ import pytorch_lightning as pl
21
+
22
+ class MyPrintingCallback (pl .Callback ):
23
+
24
+ def on_init_start (self , trainer ):
25
+ print (' Starting to init trainer!' )
26
+
27
+ def on_init_end (self , trainer ):
28
+ print (' trainer is init now' )
29
+
30
+ def on_train_end (self , trainer , pl_module ):
31
+ print (' do something when training ends' )
32
+
33
+ # pass to trainer
34
+ trainer = pl.Trainer(callbacks = [MyPrintingCallback()])
35
+
36
+ We successfully extended functionality without polluting our super clean LightningModule research code
37
+
38
+ Callback Class
39
+ --------------
40
+
6
41
.. automodule :: pytorch_lightning.callbacks
7
42
:noindex:
8
43
:exclude-members:
9
44
_del_model,
10
45
_save_model,
46
+ _abc_impl,
11
47
on_epoch_end,
12
48
on_train_end,
13
49
on_epoch_start,
Original file line number Diff line number Diff line change
1
+ Hooks
2
+ -----
3
+
1
4
.. automodule :: pytorch_lightning.core.hooks
2
5
3
6
Full list of hooks
4
- ------------------
7
+
5
8
6
9
Training set-up
7
- ===============
10
+ ================
8
11
- init_ddp_connection
9
12
- init_optimizers
10
13
- configure_apex
Original file line number Diff line number Diff line change 11
11
class Callback (abc .ABC ):
12
12
"""Abstract base class used to build new callbacks."""
13
13
14
- def on_init_start (self , trainer , pl_module ):
14
+ def on_init_start (self , trainer ):
15
15
"""Called when the trainer initialization begins."""
16
- assert pl_module is None
16
+ pass
17
17
18
- def on_init_end (self , trainer , pl_module ):
18
+ def on_init_end (self , trainer ):
19
19
"""Called when the trainer initialization ends."""
20
20
pass
21
21
Original file line number Diff line number Diff line change @@ -12,15 +12,15 @@ def __init__(self):
12
12
self .callbacks : list [Callback ] = []
13
13
self .get_model : Callable = ...
14
14
15
- def on_init_start (self ):
15
+ def on_init_start (self , trainer ):
16
16
"""Called when the trainer initialization begins."""
17
17
for callback in self .callbacks :
18
- callback .on_init_start (self , None )
18
+ callback .on_init_start (trainer )
19
19
20
- def on_init_end (self ):
20
+ def on_init_end (self , trainer ):
21
21
"""Called when the trainer initialization ends."""
22
22
for callback in self .callbacks :
23
- callback .on_init_end (self , self . get_model () )
23
+ callback .on_init_end (trainer )
24
24
25
25
def on_fit_start (self ):
26
26
"""Called when the fit begins."""
Original file line number Diff line number Diff line change @@ -618,7 +618,7 @@ def on_train_end(self):
618
618
619
619
# Init callbacks
620
620
self .callbacks = callbacks
621
- self .on_init_start ()
621
+ self .on_init_start (self )
622
622
623
623
# benchmarking
624
624
self .benchmark = benchmark
@@ -808,7 +808,7 @@ def on_train_end(self):
808
808
self .init_amp (use_amp )
809
809
810
810
# Callback system
811
- self .on_init_end ()
811
+ self .on_init_end (self )
812
812
813
813
@property
814
814
def slurm_job_id (self ) -> int :
Original file line number Diff line number Diff line change @@ -630,10 +630,10 @@ def __init__(self):
630
630
self .on_test_start_called = False
631
631
self .on_test_end_called = False
632
632
633
- def on_init_start (self , trainer , pl_module ):
633
+ def on_init_start (self , trainer ):
634
634
self .on_init_start_called = True
635
635
636
- def on_init_end (self , trainer , pl_module ):
636
+ def on_init_end (self , trainer ):
637
637
self .on_init_end_called = True
638
638
639
639
def on_fit_start (self , trainer , pl_module ):
You can’t perform that action at this time.
0 commit comments