3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer` where it should be called.
"""
-
+ from typing import Any
import torch
-
+ from torch import Tensor
+ from torch.optim.optimizer import Optimizer


try:
from apex import amp
@@ -36,48 +37,45 @@ def on_sanity_check_start(self):
:return:
"""

- def on_train_start(self):
+ def on_train_start(self) -> None:
"""Called at the beginning of training before sanity check
- :return:
"""
# do something at the start of training

- def on_train_end(self):
+ def on_train_end(self) -> None:
"""
Called at the end of training before logger experiment is closed
- :return:
"""
# do something at the end of training

- def on_batch_start(self, batch):
+ def on_batch_start(self, batch: Any) -> None:
"""Called in the training loop before anything happens for that batch.
:param batch:
- :return:
"""
# do something when the batch starts

- def on_batch_end(self):
+ def on_batch_end(self) -> None:
"""Called in the training loop after the batch."""
# do something when the batch ends

- def on_epoch_start(self):
+ def on_epoch_start(self) -> None:
"""Called in the training loop at the very beginning of the epoch."""
# do something when the epoch starts

- def on_epoch_end(self):
+ def on_epoch_end(self) -> None:
"""Called in the training loop at the very end of the epoch."""
# do something when the epoch ends

- def on_pre_performance_check(self):
+ def on_pre_performance_check(self) -> None:
"""Called at the very beginning of the validation loop."""
# do something before validation starts

- def on_post_performance_check(self):
+ def on_post_performance_check(self) -> None:
"""Called at the very end of the validation loop."""
# do something before validation end

- def on_before_zero_grad(self, optimizer):
+ def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""Called after optimizer.step() and before optimizer.zero_grad()
Called in the training loop after taking an optimizer step and before zeroing grads.
@@ -89,17 +87,13 @@ def on_before_zero_grad(self, optimizer):
model.on_before_zero_grad(optimizer) # < ---- called here
optimizer.zero_grad
- :param optimizer:
- :return:
+ :param optimizer: The optimizer for which grads should be zeroed.
"""
# do something with the optimizer or inspect it.

- def on_after_backward(self):
- """Called after loss.backward() and before optimizers do anything.
-
- :return:
+ def on_after_backward(self) -> None:
+ """Called in the training loop after loss.backward() and before optimizers do anything.
- Called in the training loop after model.backward()
This is the ideal place to inspect or log gradient information
.. code-block:: python
@@ -116,14 +110,13 @@ def on_after_backward(self):
"""

- def backward(self, trainer, loss, optimizer, optimizer_idx):
+ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Override backward with your own implementation if you need to
:param trainer: Pointer to the trainer
:param loss: Loss is already scaled by accumulated grads
:param optimizer: Current optimizer being used
:param optimizer_idx: Index of the current optimizer being used
- :return:
Called to perform backward step.
Feel free to override as needed.
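
As a rough sketch of how these hooks end up being used on the user side (illustrative only, not part of this diff; the class name, printed messages, and gradient-norm computation are assumptions), a ``LightningModule`` simply overrides the hooks it needs and inherits the no-op defaults for the rest:

.. code-block:: python

    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        # hypothetical module: only the hooks of interest are overridden

        def on_train_start(self) -> None:
            # runs once, before the sanity check and the first batch
            print("training is starting")

        def on_after_backward(self) -> None:
            # gradients are populated here: loss.backward() has run,
            # but the optimizer has not stepped yet
            grads = [p.grad.norm() for p in self.parameters() if p.grad is not None]
            if grads:
                total_norm = torch.norm(torch.stack(grads)).item()
                print(f"total grad norm: {total_norm:.4f}")

        def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
            # optimizer.step() has been taken; grads are about to be cleared
            pass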