|
3 | 3 | from pytorch_lightning import Callback
|
4 | 4 | from pytorch_lightning import Trainer, LightningModule
|
5 | 5 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
|
| 6 | +from pytorch_lightning.loggers import TensorBoardLogger |
6 | 7 | from tests.base import EvalModelTemplate
|
| 8 | +from pathlib import Path |
7 | 9 |
|
8 | 10 |
|
9 | 11 | def test_trainer_callback_system(tmpdir):
|
@@ -258,6 +260,28 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
|
258 | 260 | assert trainer.ckpt_path != trainer.default_root_dir
|
259 | 261 |
|
260 | 262 |
|
@pytest.mark.parametrize(
    'logger_version,expected',
    [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
    tutils.reset_seed()

    # Build a minimal model and a TensorBoard logger whose version is
    # controlled by the parametrized case (None -> auto-assigned int 0,
    # explicit int, or an arbitrary string used verbatim).
    eval_model = EvalModelTemplate(tutils.get_default_hparams())
    tb_logger = TensorBoardLogger(str(tmpdir), version=logger_version)

    trainer = Trainer(
        default_root_dir=tmpdir,
        overfit_pct=0.2,
        max_epochs=5,
        logger=tb_logger,
    )
    trainer.fit(eval_model)

    # The checkpoint directory's parent folder name encodes the version;
    # it should carry the "version_" prefix only for integer versions.
    version_dir = Path(trainer.ckpt_path).parent.name
    assert version_dir == expected
| 283 | + |
| 284 | + |
261 | 285 | def test_lr_logger_single_lr(tmpdir):
|
262 | 286 | """ Test that learning rates are extracted and logged for single lr scheduler"""
|
263 | 287 | tutils.reset_seed()
|
|
0 commit comments