Commit f798cff

save last model after saving top_k when save_last=True (#2881)

* save_last should be last
* changelog
* seed, docs
* retrigger ci
* compare filenames
* move constants
* fix test
* epoch, global step
* improve test

1 parent d9d7e91 commit f798cff
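For context, a minimal usage sketch of the behavior this commit addresses: with `save_last=True`, `ModelCheckpoint` writes a `last.ckpt` file at each validation end, and after this change it does so only after the regular top-k checkpoint has been saved, so `last.ckpt` reflects the latest epoch, global step, and best-model bookkeeping. The arguments follow the 0.8.x-era API visible in the diff below; the directory name is illustrative.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the 3 best checkpoints by val_loss and also refresh 'last.ckpt'
# after every validation run (now written after the top-k save).
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/',  # 0.8.x-era argument name, as in this diff
    monitor='val_loss',
    save_top_k=3,
    save_last=True,
)

trainer = Trainer(checkpoint_callback=checkpoint_callback, max_epochs=10)
# trainer.fit(model)  # 'model' is any LightningModule, omitted here
```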

4 files changed: +51 −10 lines changed

CHANGELOG.md (+2)

@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LR finder and `hparams` compatibility ([#2821](https://github.com/PyTorchLightning/pytorch-lightning/pull/2821))

+- Fixed `ModelCheckpoint` not saving the latest information when `save_last=True` ([#2881](https://github.com/PyTorchLightning/pytorch-lightning/pull/2881))
+
 ## [0.8.5] - 2020-07-09

 ### Added

pytorch_lightning/callbacks/model_checkpoint.py (+8 −4)

@@ -96,6 +96,10 @@ class ModelCheckpoint(Callback):
     """

+    CHECKPOINT_NAME_LAST = "last.ckpt"
+    CHECKPOINT_STATE_BEST_SCORE = "checkpoint_callback_best_model_score"
+    CHECKPOINT_STATE_BEST_PATH = "checkpoint_callback_best_model_path"
+
     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):

@@ -302,10 +306,6 @@ def on_validation_end(self, trainer, pl_module):
         self.epoch_last_check = epoch

-        if self.save_last:
-            filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
-            self._save_model(filepath, trainer, pl_module)
-
         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
         while os.path.isfile(filepath):

@@ -340,6 +340,10 @@ def on_validation_end(self, trainer, pl_module):
             assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
             self._save_model(filepath, trainer, pl_module)

+        if self.save_last:
+            filepath = os.path.join(self.dirpath, self.prefix + ModelCheckpoint.CHECKPOINT_NAME_LAST)
+            self._save_model(filepath, trainer, pl_module)
+
     def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
         # remove kth
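In effect, the hunks above move the `save_last` write from before the top-k handling to after it. A simplified sketch of the resulting order of operations in `on_validation_end` (a paraphrase for illustration, not the full implementation with top-k pruning and rank checks):

```python
import os

def save_checkpoints_sketch(ckpt_callback, trainer, pl_module, epoch, metrics):
    # 1. Save the regular checkpoint for this epoch first (top-k bookkeeping
    #    happens here in the real method).
    filepath = ckpt_callback.format_checkpoint_name(epoch, metrics)
    ckpt_callback._save_model(filepath, trainer, pl_module)

    # 2. Only afterwards overwrite 'last.ckpt', so it carries the same
    #    (latest) epoch, global_step and best-model state as the file above.
    if ckpt_callback.save_last:
        last_path = os.path.join(
            ckpt_callback.dirpath,
            ckpt_callback.prefix + type(ckpt_callback).CHECKPOINT_NAME_LAST,
        )
        ckpt_callback._save_model(last_path, trainer, pl_module)
```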

pytorch_lightning/trainer/training_io.py (+5 −5)

@@ -354,8 +354,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if checkpoint_callbacks:
             # we add the official checkpoint callback to the end of the list
             # extra user provided callbacks will not be persisted yet
-            checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
-            checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path
+            checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score
+            checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path

         if early_stopping_callbacks and checkpoint_callbacks:
             # we add the official early stopping callback to the end of the list

@@ -436,16 +436,16 @@ def restore_training_state(self, checkpoint):
         early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]

         if checkpoint_callbacks:
-            if 'checkpoint_callback_best_model_score' in checkpoint:
-                checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score']
+            if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint:
+                checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE]
             else:
                 # Old naming until version 0.7.6
                 rank_zero_warn(
                     'Loading a checkpoint created with an old version of Lightning; '
                     'this will not be supported in the future.'
                 )
                 checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
-            checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path']
+            checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH]

         if early_stopping_callbacks:
             state = checkpoint['early_stop_callback_state_dict']
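Since both `dump_checkpoint` and `restore_training_state` now go through the class constants, the same names can be used to inspect a saved checkpoint file directly. A small sketch, assuming a checkpoint produced by a run with a `ModelCheckpoint` callback and a hypothetical path:

```python
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

# 'checkpoints/last.ckpt' is a placeholder path for illustration.
ckpt = torch.load('checkpoints/last.ckpt', map_location='cpu')

# Entries written by dump_checkpoint() and read back by restore_training_state().
best_score = ckpt.get(ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE)
best_path = ckpt.get(ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH)
print(f"best score: {best_score}, best checkpoint path: {best_path}")
print(f"epoch: {ckpt['epoch']}, global step: {ckpt['global_step']}")
```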

tests/callbacks/test_model_checkpoint.py (+36 −1)

@@ -5,9 +5,10 @@

 import cloudpickle
 import pytest
+import torch

 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate

@@ -93,3 +94,37 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     )
     result = trainer.fit(model)
     assert 1 == result
+
+
+def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
+    """ Tests that the checkpoint saved as 'last.ckpt' contains the latest information. """
+    seed_everything(100)
+    model = EvalModelTemplate()
+    num_epochs = 3
+    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
+    )
+    trainer.fit(model)
+    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
+    path_last = str(tmpdir / ModelCheckpoint.CHECKPOINT_NAME_LAST)  # last.ckpt
+    assert path_last_epoch != path_last
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(path_last)
+    matching_keys = (
+        "epoch",
+        "global_step",
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
+        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
+    )
+    for key in matching_keys:
+        assert ckpt_last_epoch[key] == ckpt_last[key]
+
+    # it is easier to load the model objects than to iterate over the raw dict of tensors
+    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
+    model_last = EvalModelTemplate.load_from_checkpoint(path_last)
+    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
+        assert w0.eq(w1).all()
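With `last.ckpt` now holding the latest training state, it is the natural file to resume from. A brief sketch, assuming the `resume_from_checkpoint` Trainer argument of Lightning from this era and a hypothetical path:

```python
from pytorch_lightning import Trainer

# Resume from the most recent state rather than from a best-score checkpoint.
trainer = Trainer(
    resume_from_checkpoint='checkpoints/last.ckpt',  # hypothetical path
    max_epochs=10,
)
# trainer.fit(model)  # continues training from the restored epoch/global_step
```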
