added warning when changing monitor and using results obj #3014

Merged · 3 commits · Aug 17, 2020
Changes from all commits
16 changes: 16 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -14,6 +14,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import os

torch_inf = torch.tensor(np.Inf)

@@ -72,6 +73,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode
        self.warned_result_obj = False

        if mode not in self.mode_dict:
            if self.verbose > 0:
@@ -154,12 +156,26 @@ def on_train_epoch_end(self, trainer, pl_module):
        if should_check_early_stop:
            self._run_early_stopping_check(trainer, pl_module)

    def __warn_deprecated_monitor_key(self):
        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
        invalid_key = self.monitor not in ['val_loss', 'early_stop_on', 'val_early_step_on']
        if using_result_obj and not self.warned_result_obj and invalid_key:
            self.warned_result_obj = True
            m = f"""
                When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the
                'monitor' key of EarlyStopping has no effect.
                Remove EarlyStopping(monitor='{self.monitor}') to fix.
            """
            rank_zero_warn(m)

    def _run_early_stopping_check(self, trainer, pl_module):
        logs = trainer.callback_metrics

        if not self._validate_condition_metric(logs):
            return  # short circuit if metric not present

        self.__warn_deprecated_monitor_key()

        current = logs.get(self.monitor)

        # when in dev debugging
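For reference, here is a minimal, self-contained sketch of the warn-once pattern the EarlyStopping hunk above relies on: a process-wide environment flag (PL_USING_RESULT_OBJ, set by the Result objects further down in this diff) combined with a per-instance guard. The class name and standalone structure below are illustrative only, not part of the PR.

import os
import warnings


class MonitorWarningSketch:
    """Illustrative stand-in for the check added to EarlyStopping above."""

    def __init__(self, monitor, valid_keys=('val_loss', 'early_stop_on', 'val_early_step_on')):
        self.monitor = monitor
        self.valid_keys = valid_keys
        self.warned_result_obj = False  # the warning fires at most once per instance

    def warn_if_monitor_ignored(self):
        # flag is set once a Result object has been constructed in this process
        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
        invalid_key = self.monitor not in self.valid_keys
        if using_result_obj and not self.warned_result_obj and invalid_key:
            self.warned_result_obj = True
            warnings.warn(
                f"'monitor={self.monitor}' has no effect when a Result object is used",
                UserWarning,
            )


# warns only if PL_USING_RESULT_OBJ is set and the key is not one of the built-ins
MonitorWarningSketch('not_val_loss').warn_if_monitor_ignored()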
16 changes: 16 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -138,6 +138,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
        self.best_model_score = 0
        self.best_model_path = ''
        self.save_function = None
        self.warned_result_obj = False

        torch_inf = torch.tensor(np.Inf)
        mode_dict = {
@@ -297,12 +298,27 @@ def on_train_start(self, trainer, pl_module):
        if not gfile.exists(self.dirpath):
            makedirs(self.dirpath)

    def __warn_deprecated_monitor_key(self):
        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
        invalid_key = self.monitor not in ['val_loss', 'checkpoint_on']
        if using_result_obj and not self.warned_result_obj and invalid_key:
            self.warned_result_obj = True
            m = f"""
                When using EvalResult(checkpoint_on=X) or TrainResult(checkpoint_on=X) the
                'monitor' key of ModelCheckpoint has no effect.
                Remove ModelCheckpoint(monitor='{self.monitor}') to fix.
            """
            rank_zero_warn(m)

    @rank_zero_only
    def on_validation_end(self, trainer, pl_module):
        # only run on main process
        if trainer.global_rank != 0:
            return

        # TODO: remove when dict results are deprecated
        self.__warn_deprecated_monitor_key()

        metrics = trainer.callback_metrics
        epoch = trainer.current_epoch

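A hedged sketch of the configuration this warning steers users toward under the 0.9-era Result API: route the monitored value through the Result object rather than through a custom monitor key on the callback. The compute_val_loss helper below is assumed for illustration and is not part of the diff.

from pytorch_lightning.core.step_result import EvalResult


def validation_step(self, batch, batch_idx):
    # illustrative LightningModule.validation_step
    val_loss = self.compute_val_loss(batch)  # assumed helper
    # checkpoint_on / early_stop_on take the place of ModelCheckpoint(monitor=...)
    # and EarlyStopping(monitor=...) once Result objects are in use
    return EvalResult(checkpoint_on=val_loss, early_stop_on=val_loss)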
4 changes: 4 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -4,6 +4,7 @@

import torch
from torch import Tensor
import os

from pytorch_lightning.metrics.converters import _sync_ddp_if_available

@@ -20,6 +21,9 @@ def __init__(

        super().__init__()

        # temporary until dict results are deprecated
        os.environ['PL_USING_RESULT_OBJ'] = '1'

        if early_stop_on is not None:
            self.early_stop_on = early_stop_on
        if checkpoint_on is not None and checkpoint_on:
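As a quick sketch of what the hunk above does, assuming the pytorch-lightning revision in this PR: constructing any Result object flips the process-wide flag that the callback checks above consult.

import os

from pytorch_lightning.core.step_result import TrainResult

# building a Result object sets the flag that EarlyStopping/ModelCheckpoint
# read before deciding whether to warn about an ignored 'monitor' key
result = TrainResult()
print(os.environ.get('PL_USING_RESULT_OBJ'))  # expected: '1'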
41 changes: 41 additions & 0 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -7,6 +7,7 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.step_result import TrainResult
from tests.base import EvalModelTemplate
from tests.base.deterministic_model import DeterministicModel
@@ -543,3 +544,43 @@ def test_result_map(tmpdir):
    assert 'x2' not in result
    assert 'y1' in result
    assert 'y2' in result


def test_result_monitor_warnings(tmpdir):
    """
    Tests that we warn when the monitor key is changed while using a Result object.
    """
    model = EvalModelTemplate()
    model.test_step = None
    model.training_step = model.training_step_result_obj
    model.training_step_end = None
    model.training_epoch_end = None
    model.validation_step = model.validation_step_result_obj
    model.validation_step_end = None
    model.validation_epoch_end = None
    model.test_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=2,
        weights_summary=None,
        checkpoint_callback=ModelCheckpoint(monitor='not_val_loss')
    )

    with pytest.warns(UserWarning, match='key of ModelCheckpoint has no effect'):
        trainer.fit(model)

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        row_log_interval=2,
        limit_train_batches=2,
        weights_summary=None,
        early_stop_callback=EarlyStopping(monitor='not_val_loss')
    )

    with pytest.warns(UserWarning, match='key of EarlyStopping has no effect'):
        trainer.fit(model)