
Commit a03a0bf

williamFalcon authored and atee committed
added warning when changing monitor and using results obj (Lightning-AI#3014)
1 parent d18da8e commit a03a0bf
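For context, the new warning covers setups where a Result object (TrainResult/EvalResult) already decides which value drives checkpointing or early stopping, so a custom monitor key passed to the callback is silently ignored. The sketch below is a hypothetical minimal example of such a setup (the module, data, and the 'my_metric' name are illustrative assumptions, not part of this commit):

    import torch
    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.core.step_result import TrainResult


    class ToyModel(LightningModule):
        # hypothetical module, only here to illustrate the warning scenario
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # the Result object already declares what checkpointing monitors
            return TrainResult(minimize=loss, checkpoint_on=loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)


    model = ToyModel()
    # 'my_metric' has no effect here because checkpoint_on is set on the TrainResult;
    # this mismatch between the callback and the Result object is what the commit warns about
    trainer = Trainer(max_epochs=1, checkpoint_callback=ModelCheckpoint(monitor='my_metric'))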

File tree: 4 files changed, +77 −0 lines changed

pytorch_lightning/callbacks/early_stopping.py

+16
@@ -14,6 +14,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+import os

 torch_inf = torch.tensor(np.Inf)

@@ -72,6 +73,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.wait_count = 0
         self.stopped_epoch = 0
         self.mode = mode
+        self.warned_result_obj = False

         if mode not in self.mode_dict:
             if self.verbose > 0:
@@ -154,12 +156,26 @@ def on_train_epoch_end(self, trainer, pl_module):
         if should_check_early_stop:
             self._run_early_stopping_check(trainer, pl_module)

+    def __warn_deprecated_monitor_key(self):
+        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
+        invalid_key = self.monitor not in ['val_loss', 'early_stop_on', 'val_early_step_on']
+        if using_result_obj and not self.warned_result_obj and invalid_key:
+            self.warned_result_obj = True
+            m = f"""
+                When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
+                'monitor' key of EarlyStopping has no effect.
+                Remove EarlyStopping(monitor='{self.monitor}') to fix.
+            """
+            rank_zero_warn(m)
+
     def _run_early_stopping_check(self, trainer, pl_module):
         logs = trainer.callback_metrics

         if not self._validate_condition_metric(logs):
             return  # short circuit if metric not present

+        self.__warn_deprecated_monitor_key()
+
         current = logs.get(self.monitor)

         # when in dev debugging
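In practice the check above only stays silent for the 'val_loss', 'early_stop_on', and 'val_early_step_on' keys. A minimal sketch of a configuration that avoids the warning, assuming the 0.9-era Trainer(early_stop_callback=...) argument used elsewhere in this commit: when the Result object carries early_stop_on, keep the callback pointed at that key (or simply pass early_stop_callback=True).

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # point the callback at the key the Result object actually populates,
    # so the deprecation warning above never fires
    early_stop = EarlyStopping(monitor='early_stop_on', patience=3)
    trainer = Trainer(early_stop_callback=early_stop)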

pytorch_lightning/callbacks/model_checkpoint.py

+16
@@ -138,6 +138,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         self.best_model_score = 0
         self.best_model_path = ''
         self.save_function = None
+        self.warned_result_obj = False

         torch_inf = torch.tensor(np.Inf)
         mode_dict = {
@@ -297,12 +298,27 @@ def on_train_start(self, trainer, pl_module):
         if not gfile.exists(self.dirpath):
             makedirs(self.dirpath)

+    def __warn_deprecated_monitor_key(self):
+        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
+        invalid_key = self.monitor not in ['val_loss', 'checkpoint_on']
+        if using_result_obj and not self.warned_result_obj and invalid_key:
+            self.warned_result_obj = True
+            m = f"""
+                When using EvalResult(checkpoint_on=X) or TrainResult(checkpoint_on=X) the
+                'monitor' key of ModelCheckpoint has no effect.
+                Remove ModelCheckpoint(monitor='{self.monitor}') to fix.
+            """
+            rank_zero_warn(m)
+
     @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
         # only run on main process
         if trainer.global_rank != 0:
             return

+        # TODO: remove when dict results are deprecated
+        self.__warn_deprecated_monitor_key()
+
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch

pytorch_lightning/core/step_result.py

+4
@@ -4,6 +4,7 @@

 import torch
 from torch import Tensor
+import os

 from pytorch_lightning.metrics.converters import _sync_ddp_if_available

@@ -20,6 +21,9 @@ def __init__(

         super().__init__()

+        # temporary until dict results are deprecated
+        os.environ['PL_USING_RESULT_OBJ'] = '1'
+
         if early_stop_on is not None:
             self.early_stop_on = early_stop_on
         if checkpoint_on is not None and checkpoint_on:
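The constructor change above is the producer side of a simple process-wide flag: step_result.py marks that a Result object is in use, and the two callbacks, which never see the Result instance directly, read the flag back via os.environ. A small standalone sketch of that handshake (names mirror the diff; the print is illustrative only):

    import os

    # producer side (what the Result constructor does): mark the process
    os.environ['PL_USING_RESULT_OBJ'] = '1'

    # consumer side (what the callbacks do): environment values are strings,
    # so only presence matters, not the particular value
    using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
    if using_result_obj:
        print('Result objects in use; custom monitor keys will be ignored')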

tests/trainer/test_trainer_steps_result_return.py

+41
@@ -7,6 +7,7 @@
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.step_result import TrainResult
 from tests.base import EvalModelTemplate
 from tests.base.deterministic_model import DeterministicModel
@@ -543,3 +544,43 @@ def test_result_map(tmpdir):
     assert 'x2' not in result
     assert 'y1' in result
     assert 'y2' in result
+
+
+def test_result_monitor_warnings(tmpdir):
+    """
+    Tests that we warn when the monitor key is changed and we use the Results obj.
+    """
+    model = EvalModelTemplate()
+    model.test_step = None
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    model.test_dataloader = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        early_stop_callback=True,
+        row_log_interval=2,
+        limit_train_batches=2,
+        weights_summary=None,
+        checkpoint_callback=ModelCheckpoint(monitor='not_val_loss')
+    )
+
+    with pytest.warns(UserWarning, match='key of ModelCheckpoint has no effect'):
+        trainer.fit(model)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        row_log_interval=2,
+        limit_train_batches=2,
+        weights_summary=None,
+        early_stop_callback=EarlyStopping(monitor='not_val_loss')
+    )
+
+    with pytest.warns(UserWarning, match='key of EarlyStopping has no effect'):
+        trainer.fit(model)
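The assertions rely on pytest.warns(..., match=...), where match is applied as a regular-expression search against the warning message, so a distinctive substring of the multi-line message is enough. A tiny standalone illustration (the emit() helper is made up for this example):

    import warnings

    import pytest


    def emit():
        warnings.warn("the 'monitor' key of EarlyStopping has no effect", UserWarning)


    # match is searched (re.search) against the message, so a substring is sufficient
    with pytest.warns(UserWarning, match='has no effect'):
        emit()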
