Commit e9846dd
Add tracking of basic states in Trainer [wip - to-be-merged after v0.9] (#2541)

* Add initial tracking of states in Trainer.
* Add INTERRUPTED state, improve tests, move state switching from callback to a trainer.
* Move part of a trainer state switching to a decorator.
* Add documentation.
* Fix docs, rename state enum, restore state to previous on exit if None, add tests for decorator only.
* Fix callback typing.

Co-authored-by: William Falcon <[email protected]>
1 parent 13fe0a4 commit e9846dd

4 files changed: +266 −0 lines changed

pytorch_lightning/trainer/states.py (new file, +49 lines)

```python
from enum import Enum
from functools import wraps
from typing import Callable, Optional

import pytorch_lightning


class TrainerState(Enum):
    """ State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
    to indicate what is currently or was executed. """
    INITIALIZING = 'INITIALIZING'
    RUNNING = 'RUNNING'
    FINISHED = 'FINISHED'
    INTERRUPTED = 'INTERRUPTED'


def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[TrainerState] = None) -> Callable:
    """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods
    which changes state to `entering` before the function execution and `exiting`
    after the function is executed. If `None` is passed to `entering`, the state is not changed.
    If `None` is passed to `exiting`, the state is restored to the state before function execution.
    If `INTERRUPTED` state is set inside a run function, the state remains `INTERRUPTED`.
    """

    def wrapper(fn) -> Callable:
        @wraps(fn)
        def wrapped_fn(self, *args, **kwargs):
            if not isinstance(self, pytorch_lightning.Trainer):
                return fn(self, *args, **kwargs)

            state_before = self.state
            if entering is not None:
                self.state = entering
            result = fn(self, *args, **kwargs)

            # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
            # we retain INTERRUPTED state
            if self.state == TrainerState.INTERRUPTED:
                return result

            if exiting is not None:
                self.state = exiting
            else:
                self.state = state_before
            return result

        return wrapped_fn

    return wrapper
```
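
For illustration, a minimal sketch of the decorator's restore behavior when only `entering` is given, mirroring `test_state_decorator_entering_only` in the tests below. It assumes this commit's `pytorch_lightning` is importable; `run_step` is a hypothetical stand-in for a Trainer run method:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState, trainer_state

# With no `exiting` value, the decorator restores the previous state on return.
@trainer_state(entering=TrainerState.RUNNING)
def run_step(self):
    return self.state  # state as observed while the wrapped function runs

trainer = Trainer()
assert run_step(trainer) == TrainerState.RUNNING   # RUNNING inside the call
assert trainer.state == TrainerState.INITIALIZING  # restored afterwards
```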

pytorch_lightning/trainer/trainer.py (+4 lines)

```diff
@@ -45,6 +45,7 @@
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
+from pytorch_lightning.trainer.states import TrainerState, trainer_state
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
@@ -395,6 +396,7 @@ def __init__(
         self.interrupted = False
         self.should_stop = False
         self.running_sanity_check = False
+        self.state = TrainerState.INITIALIZING

         self._default_root_dir = default_root_dir or os.getcwd()
         self._weights_save_path = weights_save_path or self._default_root_dir
@@ -888,6 +890,7 @@ def weights_save_path(self) -> str:
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
+    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
     def fit(
         self,
         model: LightningModule,
@@ -1240,6 +1243,7 @@ def _run_sanity_check(self, ref_model, model):
         self.on_sanity_check_end()
         self.running_sanity_check = False

+    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
     def test(
         self,
         model: Optional[LightningModule] = None,
```
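
With `fit` and `test` wrapped this way, the trainer's lifecycle can be observed directly. A sketch, reusing the repo's `EvalModelTemplate` test helper exactly as the tests below do (so it only runs inside this repo's test environment):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from tests.base import EvalModelTemplate  # test helper from this repo

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)

trainer = Trainer(fast_dev_run=True)
assert trainer.state == TrainerState.INITIALIZING  # set at the end of __init__

trainer.fit(model)  # state is RUNNING while fit() executes
assert trainer.state == TrainerState.FINISHED  # applied by the decorator on normal exit
```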

pytorch_lightning/trainer/training_loop.py (+3 lines)

```diff
@@ -174,6 +174,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
 from pytorch_lightning.utilities import rank_zero_warn, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -253,6 +254,7 @@ class TrainerTrainLoopMixin(ABC):
     terminate_on_nan: bool
     tpu_id: int
     interactive_ddp_procs: ...
+    state: TrainerState
     amp_type: AMPType
     on_tpu: bool

@@ -418,6 +420,7 @@ def train(self):
             # user could press ctrl+c many times... only shutdown once
             if not self.interrupted:
                 self.interrupted = True
+                self.state = TrainerState.INTERRUPTED
                 self.on_keyboard_interrupt()

         self.run_training_teardown()
```
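
Because `train` sets the state itself on Ctrl+C, the decorator's INTERRUPTED check matters: the `exiting` value must not overwrite it. A sketch of that retention, mirroring `test_state_decorator_interrupt` below (`simulate_run` is a hypothetical stand-in for an interrupted run method):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState, trainer_state

@trainer_state(exiting=TrainerState.FINISHED)
def simulate_run(self):
    self.state = TrainerState.INTERRUPTED  # what train() does on KeyboardInterrupt

trainer = Trainer()
simulate_run(trainer)
assert trainer.state == TrainerState.INTERRUPTED  # FINISHED is not applied
```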

tests/trainer/test_states.py (new file, +210 lines)

```python
import pytest

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from tests.base import EvalModelTemplate


class StateSnapshotCallback(Callback):
    """ Allows snapshotting the state inside a particular trainer method. """

    def __init__(self, snapshot_method: str):
        super().__init__()
        assert snapshot_method in ['on_batch_start', 'on_test_batch_start']
        self.snapshot_method = snapshot_method
        self.trainer_state = None

    def on_batch_start(self, trainer, pl_module):
        if self.snapshot_method == 'on_batch_start':
            self.trainer_state = trainer.state

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        if self.snapshot_method == 'on_test_batch_start':
            self.trainer_state = trainer.state


def test_state_decorator_nothing_passed(tmpdir):
    """ Tests that state is not changed if nothing is passed to the decorator. """

    @trainer_state()
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_entering_only(tmpdir):
    """ Tests that state is set to `entering` inside a run function and restored to the previous value after. """

    @trainer_state(entering=TrainerState.RUNNING)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.INITIALIZING


def test_state_decorator_exiting_only(tmpdir):
    """ Tests that state is not changed inside a run function and set to `exiting` after. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.INITIALIZING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_entering_and_exiting(tmpdir):
    """ Tests that state is set to `entering` inside a run function and set to `exiting` after. """

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test_method(self):
        return self.state

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    snapshot_state = test_method(trainer)

    assert snapshot_state == TrainerState.RUNNING
    assert trainer.state == TrainerState.FINISHED


def test_state_decorator_interrupt(tmpdir):
    """ Tests that state remains `INTERRUPTED` if it is set inside a run function. """

    @trainer_state(exiting=TrainerState.FINISHED)
    def test_method(self):
        self.state = TrainerState.INTERRUPTED

    trainer = Trainer(default_root_dir=tmpdir)
    trainer.state = TrainerState.INITIALIZING

    test_method(trainer)
    assert trainer.state == TrainerState.INTERRUPTED


def test_initialize_state(tmpdir):
    """ Tests that state is INITIALIZING after Trainer creation. """
    trainer = Trainer(default_root_dir=tmpdir)
    assert trainer.state == TrainerState.INITIALIZING


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_running_state_during_fit(tmpdir, extra_params):
    """ Tests that state is set to RUNNING during fit. """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    snapshot_callback = StateSnapshotCallback(snapshot_method='on_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_finished_state_after_fit(tmpdir, extra_params):
    """ Tests that state is FINISHED after fit. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert trainer.state == TrainerState.FINISHED


def test_running_state_during_test(tmpdir):
    """ Tests that state is set to RUNNING during test. """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    snapshot_callback = StateSnapshotCallback(snapshot_method='on_test_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING


def test_finished_state_after_test(tmpdir):
    """ Tests that state is FINISHED after test. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.test(model)

    assert trainer.state == TrainerState.FINISHED


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
])
def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params):
    """ Tests that state is set to INTERRUPTED on KeyboardInterrupt. """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    class InterruptCallback(Callback):
        def __init__(self):
            super().__init__()

        def on_batch_start(self, trainer, pl_module):
            raise KeyboardInterrupt

    trainer = Trainer(
        callbacks=[InterruptCallback()],
        default_root_dir=tmpdir,
        **extra_params
    )

    trainer.fit(model)

    assert trainer.state == TrainerState.INTERRUPTED
```
