Skip to content

Commit 2f01c03

Browse files
schwobrwilliamFalcon
schwobr
authored andcommitted
Additional hooks (#598)
* Renamed `on_sanity_check_start` to `on_train_start` and added `on_train_end` to `ModelHooks` * changed tests to use `on_train_start` instead of `on_sanity_check_start`
1 parent 1051c18 commit 2f01c03

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

pytorch_lightning/core/hooks.py

+16
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,26 @@ class ModelHooks(torch.nn.Module):
2828
def on_sanity_check_start(self):
2929
"""
3030
Called before starting evaluate
31+
.. warning:: will be deprecated.
3132
:return:
3233
"""
3334
pass
3435

36+
def on_train_start(self):
37+
"""Called at the beginning of training before sanity check
38+
:return:
39+
"""
40+
# do something at the start of training
41+
pass
42+
43+
def on_train_end(self):
44+
"""
45+
Called at the end of training before logger experiment is closed
46+
:return:
47+
"""
48+
# do something at the end of training
49+
pass
50+
3551
def on_batch_start(self, batch):
3652
"""Called in the training loop before anything happens for that batch.
3753

pytorch_lightning/trainer/trainer.py

+1
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ def run_pretrain_routine(self, model):
489489
# run tiny validation (if validation defined)
490490
# to make sure program won't crash during val
491491
ref_model.on_sanity_check_start()
492+
ref_model.on_train_start()
492493
if self.get_val_dataloaders() is not None and self.num_sanity_val_steps > 0:
493494
# init progress bars for validation sanity check
494495
pbar = tqdm.tqdm(desc='Validation sanity check',

pytorch_lightning/trainer/training_loop.py

+2
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,8 @@ def train(self):
331331

332332
self.main_progress_bar.close()
333333

334+
model.on_train_end()
335+
334336
if self.logger is not None:
335337
self.logger.finalize("success")
336338

tests/test_restore_models.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def assert_good_acc():
247247

248248
# new model
249249
model = LightningTestModel(hparams)
250-
model.on_sanity_check_start = assert_good_acc
250+
model.on_train_start = assert_good_acc
251251

252252
# fit new model which should load hpc weights
253253
new_trainer.fit(model)
@@ -311,7 +311,7 @@ def assert_good_acc():
311311
for dataloader in trainer.get_val_dataloaders():
312312
tutils.run_prediction(dataloader, trainer.model)
313313

314-
model.on_sanity_check_start = assert_good_acc
314+
model.on_train_start = assert_good_acc
315315

316316
# by calling fit again, we trigger training, loading weights from the cluster
317317
# and our hook to predict using current model before any more weight updates

0 commit comments

Comments
 (0)