Commit 17947ab

Add INTERRUPTED state, improve tests, move state switching from callback to the Trainer.
1 parent 482010c commit 17947ab

File tree (4 files changed, +103 −32 lines):

  pytorch_lightning/trainer/states.py
  pytorch_lightning/trainer/trainer.py
  pytorch_lightning/trainer/training_loop.py
  tests/trainer/test_states.py
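In short, the _TrainerStateSwitcher callback is dropped and the Trainer now sets its own state attribute directly: RUNNING when fit()/test() start, FINISHED when they complete, and the new INTERRUPTED value when training is stopped with ctrl+c. A minimal sketch of how the state can be inspected from user code; the model used here is a placeholder LightningModule, and the assertions mirror the behaviour exercised in tests/trainer/test_states.py:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState

trainer = Trainer(fast_dev_run=True)
assert trainer.state == TrainerState.INITIALIZE   # state right after construction

trainer.fit(model)  # `model` is a placeholder LightningModule
# FINISHED after a normal run; INTERRUPTED if fit() was stopped with ctrl+c
assert trainer.state in (TrainerState.FINISHED, TrainerState.INTERRUPTED)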

pytorch_lightning/trainer/states.py (+6 −28)
@@ -1,31 +1,9 @@
-from enum import Enum, auto
-
-from pytorch_lightning import Callback
+from enum import Enum
 
 
 class TrainerState(Enum):
-    """ State which is set to the Trainer to indicate what is being executed. """
-    INITIALIZE = auto()
-    RUNNING = auto()
-    FINISHED = auto()
-
-
-class _TrainerStateSwitcher(Callback):
-    """ Special callback used by the Trainer. This callback sets proper
-    state to the trainer depending on what is being executed.
-    """
-
-    def on_init_start(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def on_init_end(self, trainer):
-        trainer.state = TrainerState.INITIALIZE
-
-    def setup(self, trainer, stage: str):
-        trainer.state = TrainerState.RUNNING
-
-    def teardown(self, trainer, stage: str):
-        trainer.state = TrainerState.FINISHED
-
-    def on_keyboard_interrupt(self, trainer, pl_module):
-        trainer.state = TrainerState.FINISHED
+    """ State which is set in the Trainer to indicate what is currently or was executed. """
+    INITIALIZE = 'INITIALIZE'
+    RUNNING = 'RUNNING'
+    FINISHED = 'FINISHED'
+    INTERRUPTED = 'INTERRUPTED'
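Giving the members explicit string values (instead of enum.auto()) makes the state readable when printed or logged and lets it round-trip through its value. A small illustration of standard Enum behaviour, not code from this commit:

from pytorch_lightning.trainer.states import TrainerState

state = TrainerState.RUNNING
print(state.value)                        # prints 'RUNNING'
assert TrainerState('RUNNING') is state   # lookup by value works with string members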

pytorch_lightning/trainer/trainer.py (+10 −3)
@@ -26,7 +26,7 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.states import _TrainerStateSwitcher, TrainerState
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -419,8 +419,6 @@ def __init__(
         # init callbacks
         self.callbacks = callbacks or []
 
-        self.callbacks.append(_TrainerStateSwitcher())
-
         # configure early stop callback
         # creates a default one if none passed in
         early_stop_callback = self.configure_early_stopping(early_stop_callback)
@@ -914,6 +912,8 @@ def fit(
         # check that model is configured correctly
         self.check_model_configuration(model)
 
+        self.state = TrainerState.RUNNING
+
         # callbacks
         self.on_fit_start()
         if self.is_function_implemented('on_fit_start', model):
@@ -1031,6 +1031,8 @@ def fit(
         if self.is_function_implemented('teardown'):
             model.teardown('fit')
 
+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
         # return 1 when finished
         # used for testing or when we need to know that training succeeded
         return results or 1
@@ -1246,6 +1248,8 @@ def test(
         if self.is_function_implemented('setup', model_ref):
             model_ref.setup('test')
 
+        self.state = TrainerState.RUNNING
+
         # if user requests the best checkpoint but we don't have it, error
         if model is None and ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
             raise MisconfigurationException(
@@ -1295,6 +1299,9 @@ def test(
             model_ref = self.get_model()
             model_ref.teardown('test')
 
+        if self.state != TrainerState.INTERRUPTED:
+            self.state = TrainerState.FINISHED
+
         return results
 
     def check_model_configuration(self, model: LightningModule):
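The effect of the trainer.py changes is the same in fit() and test(): the state is set to RUNNING up front and only flipped to FINISHED at the end if nothing switched it to INTERRUPTED in the meantime. A condensed sketch of that control flow (run_training_or_test is a placeholder, not a real Trainer method, and the real bodies are much longer):

def fit(self, model):
    self.state = TrainerState.RUNNING
    results = self.run_training_or_test(model)   # placeholder for the real training path
    if self.state != TrainerState.INTERRUPTED:   # left as INTERRUPTED after ctrl+c
        self.state = TrainerState.FINISHED
    return results or 1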

pytorch_lightning/trainer/training_loop.py (+3 −0)
@@ -159,6 +159,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -240,6 +241,7 @@ class TrainerTrainLoopMixin(ABC):
     terminate_on_nan: bool
     tpu_id: int
     interactive_ddp_procs: ...
+    state: TrainerState
 
     # Callback system
     callbacks: List[Callback]
@@ -397,6 +399,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()
 
         self.run_training_teardown()
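Because the train loop sets the state before calling on_keyboard_interrupt(), callbacks can already observe INTERRUPTED inside that hook. A minimal, hypothetical callback to illustrate (not part of this commit):

from pytorch_lightning import Callback
from pytorch_lightning.trainer.states import TrainerState

class InterruptLogger(Callback):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # the state was switched to INTERRUPTED just before this hook fired
        print('training interrupted, trainer.state =', trainer.state.value)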

tests/trainer/test_states.py (+84 −1)
@@ -1,4 +1,4 @@
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import EvalModelTemplate
 
@@ -15,6 +15,35 @@ def test_initialize_state(tmpdir):
     assert trainer.state == TrainerState.INITIALIZE
 
 
+def test_running_state_during_fit(tmpdir):
+    """
+    Tests that state is set to RUNNING during fit
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_fit(tmpdir):
     """
     Tests that state is FINISHED after fit
@@ -32,6 +61,35 @@ def test_finished_state_after_fit(tmpdir):
     assert trainer.state == TrainerState.FINISHED
 
 
+def test_running_state_during_test(tmpdir):
+    """
+    Tests that state is set to RUNNING during test
+    """
+
+    class StateSnapshotCallback(Callback):
+        def __init__(self):
+            super().__init__()
+            self.trainer_state = None
+
+        def on_test_batch_start(self, trainer, pl_module):
+            self.trainer_state = trainer.state
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    snapshot_callback = StateSnapshotCallback()
+
+    trainer = Trainer(
+        callbacks=[snapshot_callback],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.test(model)
+
+    assert snapshot_callback.trainer_state == TrainerState.RUNNING
+
+
 def test_finished_state_after_test(tmpdir):
     """
     Tests that state is FINISHED after fit
@@ -47,3 +105,28 @@ def test_finished_state_after_test(tmpdir):
     trainer.test(model)
 
     assert trainer.state == TrainerState.FINISHED
+
+
+def test_interrupt_state_on_keyboard_interrupt(tmpdir):
+    """
+    Tests that state is set to INTERRUPTED on KeyboardInterrupt
+    """
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(**hparams)
+
+    class InterruptCallback(Callback):
+        def __init__(self):
+            super().__init__()
+
+        def on_batch_start(self, trainer, pl_module):
+            raise KeyboardInterrupt
+
+    trainer = Trainer(
+        callbacks=[InterruptCallback()],
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+    )
+
+    trainer.fit(model)
+
+    assert trainer.state == TrainerState.INTERRUPTED
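The new tests can be run on their own with pytest against tests/trainer/test_states.py from a development checkout of the repository; fast_dev_run=True keeps each test to a single batch, so they stay quick.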
