Commit 8518663

Attach version_ to checkpoint path only if version is int (#1748)
1 parent: 0cb58fb
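
Before this change, the Trainer always built the default checkpoint directory as `version_{self.logger.version}`, so a logger carrying a string version (e.g. 'awesome') produced a folder named `version_awesome`. The `version_` prefix is now attached only when the logger reports a non-string (integer) version; a string version is used as the directory name as-is.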

3 files changed (+30 −2 lines)

CHANGELOG.md (+3 −1)
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
 
 ### Changed
-
+
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
 
 - Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))
 
+- Fixed a bug in Trainer that prepended the checkpoint path with `version_` when it shouldn't ([#1748](https://github.com/PyTorchLightning/pytorch-lightning/pull/1748))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed

pytorch_lightning/trainer/callback_config.py (+3 −1)
@@ -49,10 +49,12 @@ def configure_checkpoint_callback(self):
                 if self.weights_save_path is not None:
                     save_dir = self.weights_save_path
 
+                version = self.logger.version if isinstance(
+                    self.logger.version, str) else f'version_{self.logger.version}'
                 ckpt_path = os.path.join(
                     save_dir,
                     self.logger.name,
-                    f'version_{self.logger.version}',
+                    version,
                     "checkpoints"
                 )
             else:
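
The heart of the fix is the new conditional: the `version_` prefix is applied only when the logger's version is not already a string. A minimal, self-contained sketch of that logic (the function name and example values below are illustrative, not part of the Trainer's API):

    import os

    def checkpoint_dir(save_dir, name, version):
        # An integer logger version gets the auto-generated "version_<n>"
        # folder; a user-supplied string version is used verbatim.
        folder = version if isinstance(version, str) else f'version_{version}'
        return os.path.join(save_dir, name, folder, "checkpoints")

    # On POSIX:
    #   checkpoint_dir('logs', 'default', 1)          -> 'logs/default/version_1/checkpoints'
    #   checkpoint_dir('logs', 'default', 'awesome')  -> 'logs/default/awesome/checkpoints'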

tests/callbacks/test_callbacks.py (+24 −0)
@@ -3,7 +3,9 @@
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
 from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
+from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -258,6 +260,28 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     assert trainer.ckpt_path != trainer.default_root_dir
 
 
+@pytest.mark.parametrize(
+    'logger_version,expected',
+    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
+)
+def test_model_checkpoint_path(tmpdir, logger_version, expected):
+    """Test that "version_" prefix is only added when logger's version is an integer"""
+    tutils.reset_seed()
+    model = EvalModelTemplate(tutils.get_default_hparams())
+    logger = TensorBoardLogger(str(tmpdir), version=logger_version)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        overfit_pct=0.2,
+        max_epochs=5,
+        logger=logger
+    )
+    trainer.fit(model)
+
+    ckpt_version = Path(trainer.ckpt_path).parent.name
+    assert ckpt_version == expected
+
+
 def test_lr_logger_single_lr(tmpdir):
     """ Test that learning rates are extracted and logged for single lr scheduler"""
     tutils.reset_seed()
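
For reference, the three parametrized cases should resolve to checkpoint directories like the following under the test's tmpdir (assuming TensorBoardLogger's default experiment name 'default'; version=None lets the logger auto-assign the next integer version, which is 0 in a fresh directory):

    logger_version=None       -> <tmpdir>/default/version_0/checkpoints
    logger_version=1          -> <tmpdir>/default/version_1/checkpoints
    logger_version='awesome'  -> <tmpdir>/default/awesome/checkpoints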
