Skip to content

Commit fd1693e

Browse files
moi90, Borda, williamFalcon, awaelchli
authored
Handle KeyboardInterrupt during training (#2134)
* Handle KeyboardInterrupt during training. Fixes #2079.
* chlog
* Fix whitespace
* Update callback_hook.py
* Update base.py
* Update training_loop.py
* Update test_trainer.py
* Update CHANGELOG.md
  Co-authored-by: Adrian Wälchli <[email protected]>
* Update CHANGELOG.md
* on_keyboard_interrupt

Co-authored-by: Jirka <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent bd3a1f7 commit fd1693e

File tree

5 files changed

+29
-5
lines changed

5 files changed

+29
-5
lines changed

CHANGELOG.md

+2-1
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121

2222
### Added
2323

24-
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
2524
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
2625
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
2726
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
@@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3534
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([#1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
3635
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
3736
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
37+
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
3838

3939
### Changed
4040

@@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4646
- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
4747
- Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
4848
- Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
49+
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
4950

5051
### Deprecated
5152

pytorch_lightning/callbacks/base.py

+3
Original file line number | Diff line number | Diff line change
@@ -85,3 +85,6 @@ def on_test_start(self, trainer, pl_module):
8585
def on_test_end(self, trainer, pl_module):
8686
"""Called when the test ends."""
8787
pass
88+
89+
def on_keyboard_interrupt(self, trainer, pl_module):
90+
"""Called when the training is interrupted by KeyboardInterrupt."""

pytorch_lightning/trainer/callback_hook.py

+5
Original file line number | Diff line number | Diff line change
@@ -100,3 +100,8 @@ def on_test_end(self):
100100
"""Called when the test ends."""
101101
for callback in self.callbacks:
102102
callback.on_test_end(self, self.get_model())
103+
104+
def on_keyboard_interrupt(self):
105+
"""Called when the training is interrupted by KeyboardInterrupt."""
106+
for callback in self.callbacks:
107+
callback.on_keyboard_interrupt(self, self.get_model())

pytorch_lightning/trainer/training_loop.py

+3
Original file line number | Diff line number | Diff line change
@@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC):
237237
checkpoint_callback: ...
238238
terminate_on_nan: bool
239239
tpu_id: int
240+
interactive_ddp_procs: ...
240241

241242
# Callback system
242243
callbacks: List[Callback]
@@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC):
247248
on_epoch_start: Callable
248249
on_epoch_end: Callable
249250
on_validation_end: Callable
251+
on_keyboard_interrupt: Callable
250252

251253
@abstractmethod
252254
def get_model(self) -> LightningModule:
@@ -395,6 +397,7 @@ def train(self):
395397
# user could press ctrl+c many times... only shutdown once
396398
if not self.interrupted:
397399
self.interrupted = True
400+
self.on_keyboard_interrupt()
398401

399402
for proc in self.interactive_ddp_procs:
400403
subprocess.Popen.kill(proc)

tests/trainer/test_trainer.py

+16-4
Original file line number | Diff line number | Diff line change
@@ -3,17 +3,18 @@
33
import os
44
import pickle
55
import types
6+
import sys
67
from argparse import Namespace
78

89
import cloudpickle
910
import pytest
1011
import torch
1112

1213
import tests.base.utils as tutils
13-
from pytorch_lightning import Callback, LightningModule
14-
from pytorch_lightning import Trainer
14+
from pytorch_lightning import Callback, LightningModule, Trainer
1515
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
16-
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
16+
from pytorch_lightning.core.saving import (
17+
load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
1718
from pytorch_lightning.loggers import TensorBoardLogger
1819
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
1920
from pytorch_lightning.utilities.io import load as pl_load
@@ -660,10 +661,19 @@ def __init__(self):
660661
def on_batch_start(self, trainer, pl_module):
661662
raise KeyboardInterrupt
662663

664+
class HandleInterruptCallback(Callback):
665+
def __init__(self):
666+
super().__init__()
667+
self.exc_info = None
668+
669+
def on_keyboard_interrupt(self, trainer, pl_module):
670+
self.exc_info = sys.exc_info()
671+
663672
interrupt_callback = InterruptCallback()
673+
handle_interrupt_callback = HandleInterruptCallback()
664674

665675
trainer = Trainer(
666-
callbacks=[interrupt_callback],
676+
callbacks=[interrupt_callback, handle_interrupt_callback],
667677
max_epochs=1,
668678
val_percent_check=0.1,
669679
train_percent_check=0.2,
@@ -672,8 +682,10 @@ def on_batch_start(self, trainer, pl_module):
672682
default_root_dir=tmpdir,
673683
)
674684
assert not trainer.interrupted
685+
assert handle_interrupt_callback.exc_info is None
675686
trainer.fit(model)
676687
assert trainer.interrupted
688+
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
677689

678690

679691
def test_gradient_clipping(tmpdir):

0 commit comments

Comments
 (0)