Add step index in checkpoint name #3807

Merged · 27 commits · Nov 2, 2020
Changes from 1 commit
wip
Borda committed Oct 5, 2020
commit 07bb754260d6962a2bc00ff7b88d916a130b0ab6
33 changes: 19 additions & 14 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -204,7 +204,7 @@ def save_checkpoint(self, trainer, pl_module):
         monitor_candidates = self._monitor_candidates(trainer)

         # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+        filepath = self._get_metric_interpolated_filepath_name(epoch, global_step, monitor_candidates)

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
@@ -213,7 +213,7 @@ def save_checkpoint(self, trainer, pl_module):
         self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, global_step, filepath)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, epoch, global_step, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -323,6 +323,7 @@ def _format_checkpoint_name(
         cls,
         filename: Optional[str],
         epoch: int,
+        step: int,
         metrics: Dict[str, Any],
         prefix: str = "",
     ) -> str:
@@ -332,7 +333,7 @@ def _format_checkpoint_name(
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
         if len(groups) >= 0:
-            metrics["epoch"] = epoch
+            metrics.update({"epoch": epoch, 'step': step})
             for group in groups:
                 name = group[1:]
                 filename = filename.replace(group, name + "={" + name)
@@ -342,28 +343,28 @@ def _format_checkpoint_name(
         return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

     def format_checkpoint_name(
-        self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+        self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.

         Example::

             >>> tmpdir = os.path.dirname(__file__)
             >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
             'epoch=0.ckpt'
             >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
-            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
             'epoch=005.ckpt'
             >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
-            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
             'epoch=2-val_loss=0.12.ckpt'
             >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
             'missing=0.ckpt'
         """
         filename = self._format_checkpoint_name(
-            self.filename, epoch, metrics, prefix=self.prefix
+            self.filename, epoch, step, metrics, prefix=self.prefix
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -440,12 +441,12 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)

-    def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
-        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+    def _get_metric_interpolated_filepath_name(self, epoch: int, step: int, ckpt_name_metrics: Dict[str, Any]):
+        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
             filepath = self.format_checkpoint_name(
-                epoch, ckpt_name_metrics, ver=version_cnt
+                epoch, step, ckpt_name_metrics, ver=version_cnt
             )
             # this epoch called before
             version_cnt += 1
@@ -457,7 +458,7 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
         return ckpt_name_metrics

-    def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
@@ -467,7 +468,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+                self.CHECKPOINT_NAME_LAST,
+                trainer.current_epoch,
+                trainer.global_step,
+                ckpt_name_metrics,
+                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -239,9 +239,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         # depre warning
         if eval_results is not None and user_reduced:
             step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
-            m = f'The {step} should not return anything as of 9.1.' \
+            self.warning_cache.warn(
+                f'The {step} should not return anything as of 9.1.'
                 f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
-            self.warning_cache.warn(m)
+            )

         if using_eval_result and not user_reduced:
             eval_results = self.__auto_reduce_result_objs(outputs)
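Aside: `warning_cache.warn(...)` deduplicates repeated warnings across calls, which is why the refactor routes the message through it instead of a bare warning. A minimal sketch of the idea, assuming a simple seen-set (this is not Lightning's actual WarningCache implementation):

import warnings

class TinyWarningCache:
    # Emit each distinct warning message only once.
    def __init__(self) -> None:
        self._seen = set()

    def warn(self, message: str) -> None:
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message)

cache = TinyWarningCache()
cache.warn('validation_epoch_end should not return anything as of 9.1')
cache.warn('validation_epoch_end should not return anything as of 9.1')  # suppressed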
70 changes: 40 additions & 30 deletions tests/checkpointing/test_model_checkpoint.py
@@ -53,7 +53,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k):
     path_yaml = os.path.join(tmpdir, 'best_k_models.yaml')
     checkpoint.to_yaml(path_yaml)
     d = yaml.full_load(open(path_yaml, 'r'))
-    best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
+    best_k = {k: v for k, v in checkpoint.best_k_models.items()}
     assert d == best_k


@@ -124,7 +124,9 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
     """Test to ensure that the model callback saves the checkpoints only once in distributed mode."""
     model = EvalModelTemplate()
     num_epochs = 4
-    model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1)
+    model_checkpoint = ModelCheckpointTestInvocations(
+        filepath=tmpdir, monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1
+    )
     trainer = Trainer(
         distributed_backend="ddp_cpu",
         num_processes=2,
@@ -139,50 +141,51 @@


 def test_model_checkpoint_format_checkpoint_name(tmpdir):
     # empty filename:
-    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {})
-    assert ckpt_name == 'epoch=3'
-    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test')
-    assert ckpt_name == 'test-epoch=3'
+    ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {})
+    assert ckpt_name == 'epoch=3-step=2'
+    ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test')
+    assert ckpt_name == 'test-epoch=3-step=2'
     # no groups case:
-    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test')
+    ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test')
     assert ckpt_name == 'test-ckpt'
     # no prefix
-    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03})
+    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03})
     assert ckpt_name == 'epoch=003-acc=0.03'
     # prefix
     char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR
     ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@'
-    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test')
+    ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test')
     assert ckpt_name == 'test@epoch=3,acc=0.03000'
     ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org
     # no filepath set
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, {})
-    assert ckpt_name == 'epoch=3.ckpt'
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, {})
-    assert ckpt_name == 'epoch=5.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, 4, {})
+    assert ckpt_name == 'epoch=3-step=4.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {})
+    assert ckpt_name == 'epoch=5-step=4.ckpt'
     # CWD
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {})
-    assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {})
+    assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt'
     # dir does not exist so it is used as filename
     filepath = tmpdir / 'dir'
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
     assert ckpt_name == tmpdir / 'test-dir.ckpt'
     # now, dir exists
     os.mkdir(filepath)
-    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {})
-    assert ckpt_name == filepath / 'test-epoch=3.ckpt'
+    ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {})
+    assert ckpt_name == filepath / 'test-epoch=3-step=4.ckpt'
     # with ver
     ckpt_name = ModelCheckpoint(monitor='early_stop_on',
-                                filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3)
+                                filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, 4, {}, ver=3)
     assert ckpt_name == tmpdir / 'test-name-v3.ckpt'


 def test_model_checkpoint_save_last(tmpdir):
     """Tests that save_last produces only one last checkpoint."""
     seed_everything()
     model = EvalModelTemplate()
     epochs = 3
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}'
-    model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True)
+    model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -191,10 +194,12 @@ def test_model_checkpoint_save_last(tmpdir):
         logger=False,
     )
     trainer.fit(model)
-    last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {})
+    last_filename = model_checkpoint._format_checkpoint_name(
+        ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {}
+    )
     last_filename = last_filename + '.ckpt'
     assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
-    assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename])
+    assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename])
     ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'


@@ -229,12 +234,13 @@ def test_none_monitor_save_last(tmpdir):

 def test_model_checkpoint_none_monitor(tmpdir):
     """ Test that it is possible to save all checkpoints when monitor=None. """
+    seed_everything()
     model = EvalModelTemplate()
     model.validation_step = model.validation_step_no_monitor
     model.validation_epoch_end = model.validation_epoch_end_no_monitor

     epochs = 2
-    checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir, save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir / '{step}', save_top_k=-1)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
@@ -246,28 +252,29 @@

     # these should not be set if monitor is None
     assert checkpoint_callback.monitor is None
-    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt'
+    assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt'
     assert checkpoint_callback.best_model_score == 0
     assert checkpoint_callback.best_k_models == {}
     assert checkpoint_callback.kth_best_model_path == ''

     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)]
+    expected = [f'step={i}.ckpt' for i in [9, 19, 20]]
     assert set(os.listdir(tmpdir)) == set(expected)


 @pytest.mark.parametrize("period", list(range(4)))
 def test_model_checkpoint_period(tmpdir, period):
     model = EvalModelTemplate()
     epochs = 5
-    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, period=period)
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', save_top_k=-1, period=period)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=checkpoint_callback,
         max_epochs=epochs,
         limit_train_batches=0.1,
         limit_val_batches=0.1,
+        val_check_interval=1.0,
         logger=False,
     )
     trainer.fit(model)
@@ -304,13 +311,14 @@ def test_model_checkpoint_topk_all(tmpdir):
     seed_everything(1000)
     epochs = 2
     model = EvalModelTemplate()
-    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_top_k=-1)
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', monitor="early_stop_on", save_top_k=-1)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=checkpoint_callback,
         max_epochs=epochs,
         logger=False,
+        val_check_interval=1.0,
     )
     trainer.fit(model)
     assert checkpoint_callback.best_model_path == tmpdir / "epoch=1.ckpt"
@@ -364,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir):

     assert len(results) == 1
     assert results[0]['test_acc'] >= 0.80
-    assert len(trainer.dev_debugger.checkpoint_callback_history) == 3
+    assert len(trainer.dev_debugger.checkpoint_callback_history) == 4

     # make sure the checkpoint we saved has the metric in the name
     ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints'))
     assert len(ckpts) == 1
-    assert ckpts[0] == 'epoch=2.ckpt'
+    assert ckpts[0] == 'epoch=2-step=15.ckpt'


def test_ckpt_metric_names_results(tmpdir):
@@ -426,19 +434,21 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     model = EvalModelTemplate()
     num_epochs = 3
     model_checkpoint = ModelCheckpoint(
-        monitor='early_stop_on', filepath=tmpdir, save_top_k=num_epochs, save_last=True
+        monitor='early_stop_on', filepath=tmpdir / '{epoch}', save_top_k=num_epochs, save_last=True
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
         max_epochs=num_epochs,
+        val_check_interval=1.0,
     )
     trainer.fit(model)

     path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
     path_last = str(tmpdir / "last.ckpt")
     assert path_last == model_checkpoint.last_model_path
     assert os.path.isfile(path_last_epoch)

     ckpt_last_epoch = torch.load(path_last_epoch)
     ckpt_last = torch.load(path_last)
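Taken together, the test updates pin down the new naming scheme. A minimal usage sketch of what this commit is expected to produce, assuming the filepath-based constructor as it exists at this commit (later releases replaced it with dirpath/filename arguments):

import os
from pytorch_lightning.callbacks import ModelCheckpoint

# The default template now carries the step index alongside the epoch.
ckpt = ModelCheckpoint(monitor='early_stop_on', filepath=None)
name = ckpt.format_checkpoint_name(3, 4, {})
assert os.path.basename(name) == 'epoch=3-step=4.ckpt'

# A '{step}' template names every checkpoint after the global step, as in
# test_model_checkpoint_none_monitor (step=9.ckpt, step=19.ckpt, step=20.ckpt).
ckpt = ModelCheckpoint(monitor=None, filepath='checkpoints/{step}', save_top_k=-1)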