
Commit 613536a

williamFalcon, jamesjjcondon, and Borda authored and committed
model_checkpoint to save all models (Lightning-AI#1359)
* model_checkpoint to save all models
* changelog
* rise if

Co-authored-by: jamesjjcondon <[email protected]>
Co-authored-by: J. Borovec <[email protected]>
1 parent 08132b3 commit 613536a

2 files changed (+4 -3 lines)


CHANGELOG.md (+2 -1)
@@ -56,7 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

-- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
+- Fixed `model_checkpoint` when saving all models ([#1359](https://github.com/PyTorchLightning/pytorch-lightning/pull/1359))
+- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147))
 - Fixed bug related to type cheking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
 - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
 - Fixed a bug that created an extra dataloader with active `reload_dataloaders_every_epoch` ([#1181](https://github.com/PyTorchLightning/pytorch-lightning/issues/1181)

pytorch_lightning/callbacks/model_checkpoint.py (+2 -2)
@@ -82,7 +82,7 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
                  save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
             warnings.warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
@@ -219,7 +219,7 @@ def on_validation_end(self, trainer, pl_module):

     def _do_check_save(self, filepath, current, epoch):
         # remove kth
-        if len(self.best_k_models) == self.save_top_k:
+        if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
             delpath = self.kth_best_model
             self.best_k_models.pop(self.kth_best_model)
             self._del_model(delpath)
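In `_do_check_save`, the eviction branch that deletes the current k-th best checkpoint now also requires a positive `save_top_k`. With `save_top_k=-1`, `len(self.best_k_models)` can never equal `-1` anyway, so the added condition mainly makes the save-all intent explicit and guards against accidental eviction. A hedged usage sketch, assuming the 0.7-era API where the callback is passed to `Trainer` as `checkpoint_callback` (the directory path is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# save_top_k=-1 asks ModelCheckpoint to keep a checkpoint from every
# validation epoch: with this fix the eviction branch above is skipped
# and no spurious "will be deleted" warning fires on a non-empty directory.
checkpoint_callback = ModelCheckpoint(filepath='checkpoints/', save_top_k=-1)
trainer = Trainer(checkpoint_callback=checkpoint_callback)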
