
Commit 50881c0

Check early stopping metric in the beginning of the training (#542)
* Early stopping fix
* Update trainer.py
* Don't force validation sanity check
* fix tests
* update
* Added early_stopping check_metrics
* Updated docs
* Update docs
* Do not call early stopping when validation is disabled

Co-authored-by: William Falcon <[email protected]>
1 parent 588ad83 commit 50881c0

6 files changed (+65 / -22 lines)

pytorch_lightning/callbacks/pt_callbacks.py

+30 -15

@@ -71,21 +71,23 @@ class EarlyStopping(Callback):
     Stop training when a monitored quantity has stopped improving.

     Args:
-        monitor (str): quantity to be monitored.
+        monitor (str): quantity to be monitored. Default: ``'val_loss'``.
         min_delta (float): minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
-            change of less than min_delta, will count as no
-            improvement.
+            change of less than `min_delta`, will count as no
+            improvement. Default: ``0``.
         patience (int): number of epochs with no improvement
-            after which training will be stopped.
-        verbose (bool): verbosity mode.
+            after which training will be stopped. Default: ``0``.
+        verbose (bool): verbosity mode. Default: ``0``.
         mode (str): one of {auto, min, max}. In `min` mode,
             training will stop when the quantity
             monitored has stopped decreasing; in `max`
             mode it will stop when the quantity
             monitored has stopped increasing; in `auto`
             mode, the direction is automatically inferred
-            from the name of the monitored quantity.
+            from the name of the monitored quantity. Default: ``'auto'``.
+        strict (bool): whether to crash the training if `monitor` is
+            not found in the metrics. Default: ``True``.

     Example::

@@ -97,18 +99,20 @@ class EarlyStopping(Callback):
     """

     def __init__(self, monitor='val_loss',
-                 min_delta=0.0, patience=0, verbose=0, mode='auto'):
+                 min_delta=0.0, patience=0, verbose=0, mode='auto', strict=True):
         super(EarlyStopping, self).__init__()

         self.monitor = monitor
         self.patience = patience
         self.verbose = verbose
+        self.strict = strict
         self.min_delta = min_delta
         self.wait = 0
         self.stopped_epoch = 0

         if mode not in ['auto', 'min', 'max']:
-            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
+            if self.verbose > 0:
+                logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'

         if mode == 'min':
@@ -128,23 +132,34 @@ def __init__(self, monitor='val_loss',

         self.on_train_begin()

+    def check_metrics(self, logs):
+        monitor_val = logs.get(self.monitor)
+        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
+                     f' which is not available. Available metrics are:'
+                     f' `{"`, `".join(list(logs.keys()))}`')
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            elif self.verbose > 0:
+                warnings.warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
     def on_train_begin(self, logs=None):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
         self.best = np.Inf if self.monitor_op == np.less else -np.Inf

     def on_epoch_end(self, epoch, logs=None):
-        current = logs.get(self.monitor)
         stop_training = False
-        if current is None:
-            warnings.warn(
-                f'Early stopping conditioned on metric `{self.monitor}`'
-                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
-                RuntimeWarning)
-            stop_training = True
+        if not self.check_metrics(logs):
             return stop_training

+        current = logs.get(self.monitor)
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
             self.wait = 0
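
The new `strict` flag and `check_metrics` helper can be exercised directly. A minimal sketch follows; only the `EarlyStopping` constructor and the `check_metrics(logs)` signature come from this diff, while the metric dictionaries and values are made up for illustration:

    from pytorch_lightning.callbacks import EarlyStopping

    # strict=True (the default): a missing monitored metric raises immediately
    strict_stopper = EarlyStopping(monitor='val_loss', patience=3, strict=True)
    try:
        strict_stopper.check_metrics({'train_loss': 0.5})   # no 'val_loss' key
    except RuntimeError as err:
        print(err)   # Early stopping conditioned on metric `val_loss` which is not available. ...

    # strict=False: the same situation only warns (when verbose) and reports False,
    # so on_epoch_end() returns without touching training
    lenient_stopper = EarlyStopping(monitor='val_loss', patience=3, strict=False, verbose=1)
    assert lenient_stopper.check_metrics({'train_loss': 0.5}) is False
    assert lenient_stopper.check_metrics({'val_loss': 0.42}) is True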

pytorch_lightning/trainer/callback_config.py

+10 -0

@@ -55,10 +55,20 @@ def configure_early_stopping(self, early_stop_callback, logger):
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
                 patience=3,
+                strict=True,
                 verbose=True,
                 mode='min'
             )
             self.enable_early_stop = True
+        elif early_stop_callback is None:
+            self.early_stop_callback = EarlyStopping(
+                monitor='val_loss',
+                patience=3,
+                strict=False,
+                verbose=False,
+                mode='min'
+            )
+            self.enable_early_stop = True
         elif not early_stop_callback:
             self.early_stop_callback = None
             self.enable_early_stop = False
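
With the branch added above, `early_stop_callback` has three built-in behaviours plus the usual option of passing a configured callback. A sketch of how a user selects them (assuming the standard `Trainer` import; all other constructor arguments are left at their defaults):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # None (the new default): a lenient callback on 'val_loss' (strict=False, verbose=False);
    # if 'val_loss' is never logged, training behaves as if early stopping were disabled
    trainer = Trainer()

    # True: the strict default callback; raises if 'val_loss' is not found
    trainer = Trainer(early_stop_callback=True)

    # False: early stopping disabled entirely
    trainer = Trainer(early_stop_callback=False)

    # a fully configured callback can still be passed in explicitly
    trainer = Trainer(early_stop_callback=EarlyStopping(monitor='val_loss', patience=5, strict=True))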

pytorch_lightning/trainer/trainer.py

+16 -4

@@ -52,7 +52,7 @@ def __init__(
             self,
             logger=True,
             checkpoint_callback=True,
-            early_stop_callback=True,
+            early_stop_callback=None,
             default_save_path=None,
             gradient_clip_val=0,
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -121,15 +121,22 @@ def __init__(
                 )

                 trainer = Trainer(checkpoint_callback=checkpoint_callback)
-        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping
+        early_stop_callback (:class:`.EarlyStopping`): Callback for early stopping. If
+            set to ``True``, then the default callback monitoring ``'val_loss'`` is created.
+            Will raise an error if ``'val_loss'`` is not found.
+            If set to ``False``, then early stopping will be disabled.
+            If set to ``None``, then the default callback monitoring ``'val_loss'`` is created.
+            If ``'val_loss'`` is not found will work as if early stopping is disabled.
+            Default: ``None``.
             Example::
                 from pytorch_lightning.callbacks import EarlyStopping

                 # default used by the Trainer
                 early_stop_callback = EarlyStopping(
                     monitor='val_loss',
                     patience=3,
-                    verbose=True,
+                    strict=False,
+                    verbose=False,
                     mode='min'
                 )

@@ -809,12 +816,17 @@ def run_pretrain_routine(self, model):
             # dummy validation progress bar
             self.val_progress_bar = tqdm.tqdm(disable=True)

-            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+            eval_results = self.evaluate(model, self.get_val_dataloaders(),
+                                         self.num_sanity_val_steps, False)
+            _, _, _, callback_metrics, _ = self.process_output(eval_results)

             # close progress bars
             self.main_progress_bar.close()
             self.val_progress_bar.close()

+            if self.enable_early_stop:
+                self.early_stop_callback.check_metrics(callback_metrics)
+
         # init progress bar
         pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
                          disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
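
Wiring `check_metrics` into `run_pretrain_routine` means a mis-named metric now fails during the validation sanity check rather than after the first epoch. A sketch of that failure mode; the model and its logged keys are hypothetical:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping

    # Suppose the LightningModule's validation step returns {'valid_loss': ...}
    # instead of {'val_loss': ...} (a hypothetical mismatch).
    early_stop = EarlyStopping(monitor='val_loss', patience=3, strict=True)
    trainer = pl.Trainer(early_stop_callback=early_stop)

    # trainer.fit(model) would now raise during the sanity check:
    #   RuntimeError: Early stopping conditioned on metric `val_loss` which is not available. ...
    # Before this change, the mismatch only surfaced as a warning plus an early stop
    # at the end of the first training epoch.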

pytorch_lightning/trainer/training_loop.py

+5 -1

@@ -346,7 +346,8 @@ def train(self):

             # early stopping
             met_min_epochs = epoch >= self.min_epochs - 1
-            if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
+            if (self.enable_early_stop and not self.disable_validation and
+                    (met_min_epochs or self.fast_dev_run)):
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
                                                                     logs=self.callback_metrics)
                 # stop training
@@ -401,6 +402,9 @@ def run_training_epoch(self):
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test=self.testing)

+                if self.enable_early_stop:
+                    self.early_stop_callback.check_metrics(self.callback_metrics)
+
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
             if should_save_log or self.fast_dev_run:
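
The extra `not self.disable_validation` guard skips the epoch-end early-stopping hook when no validation loop produces metrics. A sketch under the assumption that `disable_validation` is set when the model defines no validation step or validation is turned off:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # With validation disabled there is no 'val_loss' to monitor, so the callback below is
    # simply never consulted at epoch end; training runs to max_epochs instead of stopping
    # early on the missing metric, as the old warning path did.
    trainer = Trainer(max_epochs=3,
                      early_stop_callback=EarlyStopping(monitor='val_loss', patience=3, strict=True))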

tests/test_cpu_models.py

+3 -1

@@ -140,7 +140,8 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):
         val_percent_check=0.2,
         test_percent_check=0.2,
         checkpoint_callback=checkpoint,
-        logger=logger
+        logger=logger,
+        early_stop_callback=False
     )

     # fit model
@@ -318,6 +319,7 @@ def train_dataloader(self):
         truncated_bptt_steps=truncated_bptt_steps,
         val_percent_check=0,
         weights_summary=None,
+        early_stop_callback=False
     )

     hparams = tutils.get_hparams()

tests/test_trainer.py

+1 -1

@@ -392,7 +392,7 @@ class CurrentTestModel(
         default_save_path=tmpdir,
         max_epochs=1,
         val_percent_check=0.1,
-        train_percent_check=0.2,
+        train_percent_check=0.2
     )

     # fit model
