Skip to content

Commit f82d7fe

Browse files
updated hooks (#2850)
* modified hooks * modified hooks * modified hooks * modified hooks * modified hooks * modified hooks * modified hooks * modified hooks * modified hooks
1 parent b39f479 commit f82d7fe

14 files changed

+152
-97
lines changed

pytorch_lightning/callbacks/base.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ class Callback(abc.ABC):
1414
Abstract base class used to build new callbacks.
1515
"""
1616

17-
def setup(self, trainer, stage: str):
17+
def setup(self, trainer, pl_module, stage: str):
1818
"""Called when fit or test begins"""
1919
pass
2020

21-
def teardown(self, trainer, stage: str):
21+
def teardown(self, trainer, pl_module, stage: str):
2222
"""Called when fit or test ends"""
2323
pass
2424

@@ -30,11 +30,11 @@ def on_init_end(self, trainer):
3030
"""Called when the trainer initialization ends, model has not yet been set."""
3131
pass
3232

33-
def on_fit_start(self, trainer):
33+
def on_fit_start(self, trainer, pl_module):
3434
"""Called when fit begins"""
3535
pass
3636

37-
def on_fit_end(self, trainer):
37+
def on_fit_end(self, trainer, pl_module):
3838
"""Called when fit ends"""
3939
pass
4040

@@ -46,11 +46,11 @@ def on_sanity_check_end(self, trainer, pl_module):
4646
"""Called when the validation sanity check ends."""
4747
pass
4848

49-
def on_train_batch_start(self, trainer, pl_module):
49+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
5050
"""Called when the training batch begins."""
5151
pass
5252

53-
def on_train_batch_end(self, trainer, pl_module):
53+
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
5454
"""Called when the training batch ends."""
5555
pass
5656

@@ -90,19 +90,19 @@ def on_batch_start(self, trainer, pl_module):
9090
"""Called when the training batch begins."""
9191
pass
9292

93-
def on_validation_batch_start(self, trainer, pl_module):
93+
def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
9494
"""Called when the validation batch begins."""
9595
pass
9696

97-
def on_validation_batch_end(self, trainer, pl_module):
97+
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
9898
"""Called when the validation batch ends."""
9999
pass
100100

101-
def on_test_batch_start(self, trainer, pl_module):
101+
def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
102102
"""Called when the test batch begins."""
103103
pass
104104

105-
def on_test_batch_end(self, trainer, pl_module):
105+
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
106106
"""Called when the test batch ends."""
107107
pass
108108

@@ -118,6 +118,14 @@ def on_train_end(self, trainer, pl_module):
118118
"""Called when the train ends."""
119119
pass
120120

121+
def on_pretrain_routine_start(self, trainer, pl_module):
122+
"""Called when the pretrain routine begins."""
123+
pass
124+
125+
def on_pretrain_routine_end(self, trainer, pl_module):
126+
"""Called when the pretrain routine ends."""
127+
pass
128+
121129
def on_validation_start(self, trainer, pl_module):
122130
"""Called when the validation loop begins."""
123131
pass

pytorch_lightning/callbacks/lr_logger.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def on_train_start(self, trainer, pl_module):
6464
# Initialize for storing values
6565
self.lrs = {name: [] for name in names}
6666

67-
def on_train_batch_start(self, trainer, pl_module):
67+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
6868
latest_stat = self._extract_lr(trainer, 'step')
6969
if trainer.logger and latest_stat:
7070
trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

pytorch_lightning/callbacks/progress.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,19 @@ def on_train_start(self, trainer, pl_module):
138138
def on_epoch_start(self, trainer, pl_module):
139139
self._train_batch_idx = 0
140140

141-
def on_train_batch_end(self, trainer, pl_module):
141+
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
142142
self._train_batch_idx += 1
143143

144144
def on_validation_start(self, trainer, pl_module):
145145
self._val_batch_idx = 0
146146

147-
def on_validation_batch_end(self, trainer, pl_module):
147+
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
148148
self._val_batch_idx += 1
149149

150150
def on_test_start(self, trainer, pl_module):
151151
self._test_batch_idx = 0
152152

153-
def on_test_batch_end(self, trainer, pl_module):
153+
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
154154
self._test_batch_idx += 1
155155

156156

@@ -318,8 +318,8 @@ def on_epoch_start(self, trainer, pl_module):
318318
self.main_progress_bar.reset(convert_inf(total_batches))
319319
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch + 1}')
320320

321-
def on_train_batch_end(self, trainer, pl_module):
322-
super().on_train_batch_end(trainer, pl_module)
321+
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
322+
super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
323323
if self.is_enabled and self.train_batch_idx % self.refresh_rate == 0:
324324
self.main_progress_bar.update(self.refresh_rate)
325325
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
@@ -329,8 +329,8 @@ def on_validation_start(self, trainer, pl_module):
329329
self.val_progress_bar = self.init_validation_tqdm()
330330
self.val_progress_bar.total = convert_inf(self.total_val_batches)
331331

332-
def on_validation_batch_end(self, trainer, pl_module):
333-
super().on_validation_batch_end(trainer, pl_module)
332+
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
333+
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
334334
if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
335335
self.val_progress_bar.update(self.refresh_rate)
336336
self.main_progress_bar.update(self.refresh_rate)
@@ -349,8 +349,8 @@ def on_test_start(self, trainer, pl_module):
349349
self.test_progress_bar = self.init_test_tqdm()
350350
self.test_progress_bar.total = convert_inf(self.total_test_batches)
351351

352-
def on_test_batch_end(self, trainer, pl_module):
353-
super().on_test_batch_end(trainer, pl_module)
352+
def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
353+
super().on_test_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
354354
if self.is_enabled and self.test_batch_idx % self.refresh_rate == 0:
355355
self.test_progress_bar.update(self.refresh_rate)
356356

pytorch_lightning/core/hooks.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -77,20 +77,51 @@ def on_train_end(self) -> None:
7777
"""
7878
# do something at the end of training
7979

80-
def on_train_batch_start(self, batch: Any) -> None:
80+
def on_pretrain_routine_start(self) -> None:
81+
"""
82+
Called at the beginning of the pretrain routine (between fit and train start).
83+
84+
- fit
85+
- pretrain_routine start
86+
- pretrain_routine end
87+
- training_start
88+
89+
"""
90+
# do something at the start of the pretrain routine
91+
92+
def on_pretrain_routine_end(self) -> None:
93+
"""
94+
Called at the end of the pretrain routine (between fit and train start).
95+
96+
- fit
97+
- pretrain_routine start
98+
- pretrain_routine end
99+
- training_start
100+
101+
"""
102+
# do something at the end of the pretrain routine
103+
104+
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
81105
"""
82106
Called in the training loop before anything happens for that batch.
83107
84108
If you return -1 here, you will skip training for the rest of the current epoch.
85109
86110
Args:
87111
batch: The batched data as it is returned by the training DataLoader.
112+
batch_idx: the index of the batch
113+
dataloader_idx: the index of the dataloader
88114
"""
89115
# do something when the batch starts
90116

91-
def on_train_batch_end(self) -> None:
117+
def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
92118
"""
93119
Called in the training loop after the batch.
120+
121+
Args:
122+
batch: The batched data as it is returned by the training DataLoader.
123+
batch_idx: the index of the batch
124+
dataloader_idx: the index of the dataloader
94125
"""
95126
# do something when the batch ends
96127

pytorch_lightning/trainer/callback_hook.py

+27-17
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ class TrainerCallbackHookMixin(ABC):
1414
def setup(self, stage: str):
1515
"""Called in the beginning of fit and test"""
1616
for callback in self.callbacks:
17-
callback.setup(self, stage)
17+
callback.setup(self, self.get_model(), stage)
1818

1919
def teardown(self, stage: str):
2020
"""Called at the end of fit and test"""
2121
for callback in self.callbacks:
22-
callback.teardown(self, stage)
22+
callback.teardown(self, self.get_model(), stage)
2323

2424
def on_init_start(self):
2525
"""Called when the trainer initialization begins, model has not yet been set."""
@@ -31,15 +31,15 @@ def on_init_end(self):
3131
for callback in self.callbacks:
3232
callback.on_init_end(self)
3333

34-
def on_fit_start(self):
34+
def on_fit_start(self, model):
3535
"""Called when fit begins."""
3636
for callback in self.callbacks:
37-
callback.on_fit_start(self)
37+
callback.on_fit_start(self, model)
3838

3939
def on_fit_end(self):
4040
"""Called when fit ends."""
4141
for callback in self.callbacks:
42-
callback.on_fit_end(self)
42+
callback.on_fit_end(self, self.get_model())
4343

4444
def on_sanity_check_start(self):
4545
"""Called when the validation sanity check starts."""
@@ -101,6 +101,16 @@ def on_train_end(self):
101101
for callback in self.callbacks:
102102
callback.on_train_end(self, self.get_model())
103103

104+
def on_pretrain_routine_start(self, model):
105+
"""Called when the pretrain routine begins."""
106+
for callback in self.callbacks:
107+
callback.on_pretrain_routine_start(self, model)
108+
109+
def on_pretrain_routine_end(self, model):
110+
"""Called when the pretrain routine ends."""
111+
for callback in self.callbacks:
112+
callback.on_pretrain_routine_end(self, model)
113+
104114
def on_batch_start(self):
105115
"""Called when the training batch begins."""
106116
for callback in self.callbacks:
@@ -111,35 +121,35 @@ def on_batch_end(self):
111121
for callback in self.callbacks:
112122
callback.on_batch_end(self, self.get_model())
113123

114-
def on_train_batch_start(self):
124+
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
115125
"""Called when the training batch begins."""
116126
for callback in self.callbacks:
117-
callback.on_train_batch_start(self, self.get_model())
127+
callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
118128

119-
def on_train_batch_end(self):
129+
def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
120130
"""Called when the training batch ends."""
121131
for callback in self.callbacks:
122-
callback.on_train_batch_end(self, self.get_model())
132+
callback.on_train_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)
123133

124-
def on_validation_batch_start(self):
134+
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
125135
"""Called when the validation batch begins."""
126136
for callback in self.callbacks:
127-
callback.on_validation_batch_start(self, self.get_model())
137+
callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
128138

129-
def on_validation_batch_end(self):
139+
def on_validation_batch_end(self, batch, batch_idx, dataloader_idx):
130140
"""Called when the validation batch ends."""
131141
for callback in self.callbacks:
132-
callback.on_validation_batch_end(self, self.get_model())
142+
callback.on_validation_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)
133143

134-
def on_test_batch_start(self):
144+
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
135145
"""Called when the test batch begins."""
136146
for callback in self.callbacks:
137-
callback.on_test_batch_start(self, self.get_model())
147+
callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
138148

139-
def on_test_batch_end(self):
149+
def on_test_batch_end(self, batch, batch_idx, dataloader_idx):
140150
"""Called when the test batch ends."""
141151
for callback in self.callbacks:
142-
callback.on_test_batch_end(self, self.get_model())
152+
callback.on_test_batch_end(self, self.get_model(), batch, batch_idx, dataloader_idx)
143153

144154
def on_validation_start(self):
145155
"""Called when the validation loop begins."""

pytorch_lightning/trainer/evaluation_loop.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ def _evaluate(
312312

313313
# callbacks
314314
if test_mode:
315-
self.on_test_batch_start()
315+
self.on_test_batch_start(batch, batch_idx, dataloader_idx)
316316
else:
317-
self.on_validation_batch_start()
317+
self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
318318

319319
# -----------------
320320
# RUN EVALUATION STEP
@@ -336,13 +336,13 @@ def _evaluate(
336336
model_ref = self.get_model()
337337
with self.profiler.profile('test_step_end'):
338338
output = model_ref.test_step_end(output)
339-
self.on_test_batch_end()
339+
self.on_test_batch_end(batch, batch_idx, dataloader_idx)
340340
else:
341341
if self.is_overridden('validation_step_end'):
342342
model_ref = self.get_model()
343343
with self.profiler.profile('validation_step_end'):
344344
output = model_ref.validation_step_end(output)
345-
self.on_validation_batch_end()
345+
self.on_validation_batch_end(batch, batch_idx, dataloader_idx)
346346

347347
# track outputs for collation
348348
if output is not None:

pytorch_lightning/trainer/lr_finder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def on_batch_start(self, trainer, pl_module):
384384

385385
self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
386386

387-
def on_train_batch_end(self, trainer, pl_module):
387+
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
388388
""" Called when the training batch ends, logs the calculated loss """
389389
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
390390
return

pytorch_lightning/trainer/trainer.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ def fit(
956956
self.config_validator.verify_loop_configurations(model)
957957

958958
# callbacks
959-
self.on_fit_start()
959+
self.on_fit_start(model)
960960
if self.is_function_implemented('on_fit_start', model):
961961
model.on_fit_start()
962962

@@ -1053,13 +1053,12 @@ def fit(
10531053
self.accelerator_backend.setup(model)
10541054
results = self.accelerator_backend.train(model)
10551055

1056-
# callbacks
1056+
# on fit end callback
10571057
self.on_fit_end()
1058-
1059-
# model hooks
10601058
if self.is_function_implemented('on_fit_end'):
10611059
model.on_fit_end()
10621060

1061+
# teardown callback
10631062
self.teardown('fit')
10641063
if self.is_function_implemented('teardown'):
10651064
model.teardown('fit')
@@ -1154,6 +1153,11 @@ def run_pretrain_routine(self, model: LightningModule):
11541153
# register auto-resubmit when on SLURM
11551154
self.register_slurm_signal_handlers()
11561155

1156+
# on pretrain routine start
1157+
self.on_pretrain_routine_start(ref_model)
1158+
if self.is_function_implemented('on_pretrain_routine_start'):
1159+
ref_model.on_pretrain_routine_start()
1160+
11571161
# print model summary
11581162
if self.is_global_zero and self.weights_summary is not None and not self.testing:
11591163
if self.weights_summary in ModelSummary.MODES:
@@ -1196,6 +1200,11 @@ def run_pretrain_routine(self, model: LightningModule):
11961200
with torch.cuda.device(f'cuda:{self.root_gpu}'):
11971201
torch.cuda.empty_cache()
11981202

1203+
# on pretrain routine end
1204+
self.on_pretrain_routine_end(ref_model)
1205+
if self.is_function_implemented('on_pretrain_routine_end'):
1206+
ref_model.on_pretrain_routine_end()
1207+
11991208
# CORE TRAINING LOOP
12001209
self.train()
12011210

0 commit comments

Comments
 (0)