@@ -14,12 +14,12 @@ class TrainerCallbackHookMixin(ABC):
14
14
def setup (self , stage : str ):
15
15
"""Called in the beginning of fit and test"""
16
16
for callback in self .callbacks :
17
- callback .setup (self , stage )
17
+ callback .setup (self , self . get_model (), stage )
18
18
19
19
def teardown (self , stage : str ):
20
20
"""Called at the end of fit and test"""
21
21
for callback in self .callbacks :
22
- callback .teardown (self , stage )
22
+ callback .teardown (self , self . get_model (), stage )
23
23
24
24
def on_init_start (self ):
25
25
"""Called when the trainer initialization begins, model has not yet been set."""
@@ -31,15 +31,15 @@ def on_init_end(self):
31
31
for callback in self .callbacks :
32
32
callback .on_init_end (self )
33
33
34
- def on_fit_start (self ):
34
+ def on_fit_start (self , model ):
35
35
"""Called when the trainer initialization begins, model has not yet been set."""
36
36
for callback in self .callbacks :
37
- callback .on_fit_start (self )
37
+ callback .on_fit_start (self , model )
38
38
39
39
def on_fit_end (self ):
40
40
"""Called when the trainer initialization begins, model has not yet been set."""
41
41
for callback in self .callbacks :
42
- callback .on_fit_end (self )
42
+ callback .on_fit_end (self , self . get_model () )
43
43
44
44
def on_sanity_check_start (self ):
45
45
"""Called when the validation sanity check starts."""
@@ -101,6 +101,16 @@ def on_train_end(self):
101
101
for callback in self .callbacks :
102
102
callback .on_train_end (self , self .get_model ())
103
103
104
+ def on_pretrain_routine_start (self , model ):
105
+ """Called when the train begins."""
106
+ for callback in self .callbacks :
107
+ callback .on_pretrain_routine_start (self , model )
108
+
109
+ def on_pretrain_routine_end (self , model ):
110
+ """Called when the train ends."""
111
+ for callback in self .callbacks :
112
+ callback .on_pretrain_routine_end (self , model )
113
+
104
114
def on_batch_start (self ):
105
115
"""Called when the training batch begins."""
106
116
for callback in self .callbacks :
@@ -111,35 +121,35 @@ def on_batch_end(self):
111
121
for callback in self .callbacks :
112
122
callback .on_batch_end (self , self .get_model ())
113
123
114
- def on_train_batch_start (self ):
124
+ def on_train_batch_start (self , batch , batch_idx , dataloader_idx ):
115
125
"""Called when the training batch begins."""
116
126
for callback in self .callbacks :
117
- callback .on_train_batch_start (self , self .get_model ())
127
+ callback .on_train_batch_start (self , self .get_model (), batch , batch_idx , dataloader_idx )
118
128
119
- def on_train_batch_end (self ):
129
+ def on_train_batch_end (self , batch , batch_idx , dataloader_idx ):
120
130
"""Called when the training batch ends."""
121
131
for callback in self .callbacks :
122
- callback .on_train_batch_end (self , self .get_model ())
132
+ callback .on_train_batch_end (self , self .get_model (), batch , batch_idx , dataloader_idx )
123
133
124
- def on_validation_batch_start (self ):
134
+ def on_validation_batch_start (self , batch , batch_idx , dataloader_idx ):
125
135
"""Called when the validation batch begins."""
126
136
for callback in self .callbacks :
127
- callback .on_validation_batch_start (self , self .get_model ())
137
+ callback .on_validation_batch_start (self , self .get_model (), batch , batch_idx , dataloader_idx )
128
138
129
- def on_validation_batch_end (self ):
139
+ def on_validation_batch_end (self , batch , batch_idx , dataloader_idx ):
130
140
"""Called when the validation batch ends."""
131
141
for callback in self .callbacks :
132
- callback .on_validation_batch_end (self , self .get_model ())
142
+ callback .on_validation_batch_end (self , self .get_model (), batch , batch_idx , dataloader_idx )
133
143
134
- def on_test_batch_start (self ):
144
+ def on_test_batch_start (self , batch , batch_idx , dataloader_idx ):
135
145
"""Called when the test batch begins."""
136
146
for callback in self .callbacks :
137
- callback .on_test_batch_start (self , self .get_model ())
147
+ callback .on_test_batch_start (self , self .get_model (), batch , batch_idx , dataloader_idx )
138
148
139
- def on_test_batch_end (self ):
149
+ def on_test_batch_end (self , batch , batch_idx , dataloader_idx ):
140
150
"""Called when the test batch ends."""
141
151
for callback in self .callbacks :
142
- callback .on_test_batch_end (self , self .get_model ())
152
+ callback .on_test_batch_end (self , self .get_model (), batch , batch_idx , dataloader_idx )
143
153
144
154
def on_validation_start (self ):
145
155
"""Called when the validation loop begins."""
0 commit comments