Handle KeyboardInterrupt during training #2134

Merged · 10 commits · Jun 15, 2020

3 changes: 2 additions & 1 deletion CHANGELOG.md

@@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
 - Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
 - Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
 - Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
@@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
 - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
 - Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
+- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
 
 ### Changed
 
@@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
 - Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
 - Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
+- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
 
 ### Deprecated
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/base.py

@@ -85,3 +85,6 @@ def on_test_start(self, trainer, pl_module):
     def on_test_end(self, trainer, pl_module):
         """Called when the test ends."""
         pass
+
+    def on_keyboard_interrupt(self, trainer, pl_module):
+        """Called when the training is interrupted by KeyboardInterrupt."""
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py

@@ -100,3 +100,8 @@ def on_test_end(self):
         """Called when the test ends."""
         for callback in self.callbacks:
             callback.on_test_end(self, self.get_model())
+
+    def on_keyboard_interrupt(self):
+        """Called when the training is interrupted by KeyboardInterrupt."""
+        for callback in self.callbacks:
+            callback.on_keyboard_interrupt(self, self.get_model())
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py

@@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC):
     checkpoint_callback: ...
     terminate_on_nan: bool
     tpu_id: int
+    interactive_ddp_procs: ...
 
     # Callback system
     callbacks: List[Callback]
@@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC):
     on_epoch_start: Callable
     on_epoch_end: Callable
     on_validation_end: Callable
+    on_keyboard_interrupt: Callable
 
     @abstractmethod
     def get_model(self) -> LightningModule:
@@ -395,6 +397,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.on_keyboard_interrupt()
 
             for proc in self.interactive_ddp_procs:
                 subprocess.Popen.kill(proc)
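The `self.interrupted` guard above is what keeps the hook from firing more than once when the user presses ctrl+c repeatedly. A dependency-free sketch of the same control flow (`ToyTrainer` and `ToyCallback` are illustrative stand-ins, not Lightning API):

```python
class ToyCallback:
    def on_keyboard_interrupt(self, trainer, pl_module):
        print('shutting down gracefully')


class ToyTrainer:
    """Stand-in for the shutdown path in TrainerTrainLoopMixin.train()."""

    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.interrupted = False

    def on_keyboard_interrupt(self):
        # fan the event out to every callback, as callback_hook.py does
        for callback in self.callbacks:
            callback.on_keyboard_interrupt(self, None)

    def fit(self):
        try:
            while True:
                pass  # stand-in for the real training loop
        except KeyboardInterrupt:
            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self.on_keyboard_interrupt()


# ToyTrainer([ToyCallback()]).fit()  # interrupt with Ctrl+C to see the hook run once
```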
20 changes: 16 additions & 4 deletions tests/trainer/test_trainer.py

@@ -3,17 +3,18 @@
 import os
 import pickle
 import types
+import sys
 from argparse import Namespace
 
 import cloudpickle
 import pytest
 import torch
 
 import tests.base.utils as tutils
-from pytorch_lightning import Callback, LightningModule
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
+from pytorch_lightning.core.saving import (
+    load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.utilities.io import load as pl_load
@@ -660,10 +661,19 @@ def __init__(self):
         def on_batch_start(self, trainer, pl_module):
             raise KeyboardInterrupt
 
+    class HandleInterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.exc_info = None
+
+        def on_keyboard_interrupt(self, trainer, pl_module):
+            self.exc_info = sys.exc_info()
+
     interrupt_callback = InterruptCallback()
+    handle_interrupt_callback = HandleInterruptCallback()
 
     trainer = Trainer(
-        callbacks=[interrupt_callback],
+        callbacks=[interrupt_callback, handle_interrupt_callback],
         max_epochs=1,
         val_percent_check=0.1,
         train_percent_check=0.2,
@@ -672,8 +682,10 @@ def on_batch_start(self, trainer, pl_module):
         default_root_dir=tmpdir,
     )
     assert not trainer.interrupted
+    assert handle_interrupt_callback.exc_info is None
     trainer.fit(model)
     assert trainer.interrupted
+    assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)
 
 
 def test_gradient_clipping(tmpdir):
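A detail worth noting in the test: `on_keyboard_interrupt` fires inside the trainer's `except KeyboardInterrupt:` block, so `sys.exc_info()` still reports the in-flight exception, which is exactly what `HandleInterruptCallback` captures and the final assertion checks. A self-contained illustration of that behavior:

```python
import sys


def hook():
    # inside an except block, sys.exc_info() returns the active exception
    exc_type, exc_value, _ = sys.exc_info()
    assert exc_type is KeyboardInterrupt
    assert isinstance(exc_value, KeyboardInterrupt)


try:
    raise KeyboardInterrupt
except KeyboardInterrupt:
    hook()  # the interrupt is still "active" here, just as in the trainer hook
```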