
Commit bb36623

tarepan authored and Borda committed
Add non-existing resume_from_checkpoint acceptance for auto-resubmit (#4402)
* Add empty resume_from_checkpoint acceptance #4366
* Fix general error catch with focused file check
* Add fsspec HTTP extras
  Add fsspec's HTTPFileSystem support through the http extras. pl has supported remote http files (e.g. #2925), so this commit does not add new functionality.
* Fix potential too much logging in DDP
* Add PR changelog
* Add well-written argument explanation
  Co-authored-by: Adrian Wälchli <[email protected]>
* Fix DDP-compatible restore logging
  Notify from where the states are restored. This feature was temporarily deleted as a result of PR review; with a succeeding review, it was added back with DDP compatibility.
* Fix utility import paths
* Refactor load step commentaries
* Refactor hpc ckpt suffix acquisition
* Refactor restore/hpc_load match
* Refactor hpc load trial
* Refactor checkpoint dir check
* Refactor unneeded function nest
* Refactor nested if
* Refactor duplicated cache clear
* Refactor attempt flow with if/elif
* Fix pep8
* Refactor hook commentary
  Co-authored-by: chaton <[email protected]>
* Fix pep8
* Refactor hpc load checkpoint path acquisition
* Fix pep8
* Fix typo
  Co-authored-by: Adrian Wälchli <[email protected]>
* Fix typo
  Co-authored-by: Adrian Wälchli <[email protected]>
* Fix doc
  Co-authored-by: Adrian Wälchli <[email protected]>
* Refactor None Union type with Optional
* Fix build-doc CI failure debugged in #5329
* Fix fsspec import during build-doc #5329
* Fix test epoch
  Co-authored-by: Adrian Wälchli <[email protected]>
* Fix test with latest test models
* .

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>

(cherry picked from commit b0051e8)
1 parent cc607d5 commit bb36623

File tree

8 files changed: +43 -10 lines

CHANGELOG.md (+2)

@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `resume_from_checkpoint` accept non-existing file path ([#4402](https://github.com/PyTorchLightning/pytorch-lightning/pull/4402))
+
 
 ### Removed
 

docs/source/conf.py (+5 -1)

@@ -293,10 +293,14 @@ def setup(app):
 # Ignoring Third-party packages
 # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule
 def package_list_from_file(file):
+    """List up package name (not containing version and extras) from a package list file
+    """
     mocked_packages = []
     with open(file, 'r') as fp:
         for ln in fp.readlines():
-            found = [ln.index(ch) for ch in list(',=<>#') if ch in ln]
+            # Example: `tqdm>=4.41.0` => `tqdm`
+            # `[` is for package with extras
+            found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln]
             pkg = ln[:min(found)] if found else ln
             if pkg.rstrip():
                 mocked_packages.append(pkg.rstrip())
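For reference, a standalone sketch of the parsing rule after this change (the helper name below is hypothetical, not part of the repository): adding `[` to the delimiter set means requirement entries with extras are mocked under their base package name.

```python
# Standalone reproduction of the updated rule in package_list_from_file:
# the package name ends at the first of ',=<>#[' found in the line,
# so an entry with extras such as `fsspec[http]>=0.8.1` now mocks as `fsspec`.
def _package_name(ln: str) -> str:
    found = [ln.index(ch) for ch in list(',=<>#[') if ch in ln]
    return (ln[:min(found)] if found else ln).rstrip()

assert _package_name("tqdm>=4.41.0") == "tqdm"
assert _package_name("fsspec[http]>=0.8.1") == "fsspec"  # was "fsspec[http]" before the `[` delimiter
```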

environment.yml (+1 -1)

@@ -30,7 +30,7 @@ dependencies:
   - future>=0.17.1
   - PyYAML>=5.1
   - tqdm>=4.41.0
-  - fsspec>=0.8.0
+  - fsspec[http]>=0.8.1
   #- tensorboard>=2.2.0  # not needed, already included in pytorch
 
   # Optional
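A minimal sketch of what the `fsspec[http]` extra provides (the URL below is a placeholder): with the extra's dependencies installed, fsspec can resolve the "http" protocol, so the same existence check and file read used for local checkpoints also work for remote http(s) paths.

```python
import fsspec

# The "http" protocol is only available when fsspec's http extra is installed.
fs = fsspec.filesystem("http")
checkpoint_path = "https://example.com/checkpoints/last.ckpt"  # placeholder URL
print(fs.exists(checkpoint_path))  # False for the placeholder; True for a reachable checkpoint file
```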

pytorch_lightning/trainer/connectors/checkpoint_connector.py (+10 -2)

@@ -42,7 +42,7 @@ def __init__(self, trainer):
         # used to validate checkpointing logic
         self.has_trained = False
 
-    def restore_weights(self, model: LightningModule):
+    def restore_weights(self, model: LightningModule) -> None:
         """
         Attempt to restore a checkpoint (e.g. weights) in this priority:
         1. from HPC weights
@@ -72,11 +72,16 @@ def restore_weights(self, model: LightningModule):
         if self.trainer.on_gpu:
             torch.cuda.empty_cache()
 
-    def restore(self, checkpoint_path: str, on_gpu: bool):
+    def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
         """
         Load model/training states from a 'PyTorch-Lightning checkpoint' file through file-read and state-restore.
         All restored states are listed in return value description of `dump_checkpoint`.
         """
+        # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
+        fs = get_filesystem(checkpoint_path)
+        if not fs.exists(checkpoint_path):
+            rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
+            return False
 
         # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
         checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
@@ -93,6 +98,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # restore training state
         self.restore_training_state(checkpoint)
 
+        rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
+        return True
+
     def restore_model_state(self, model: LightningModule, checkpoint) -> None:
         """
         Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object
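To illustrate the new contract, a hypothetical caller could branch on the boolean that `restore` now returns; this sketch is illustrative only and is not the connector's actual call site.

```python
# Hypothetical wrapper around CheckpointConnector.restore (illustration only):
# `restore` now reports whether a checkpoint was actually loaded instead of
# assuming the file exists, so resuming logic can fall back to a fresh start.
def maybe_resume(connector, checkpoint_path: str, on_gpu: bool) -> bool:
    restored = connector.restore(checkpoint_path, on_gpu)
    if not restored:
        # e.g. the first run of an auto-resubmitted job, before any checkpoint
        # has been written: training simply starts from scratch.
        print(f"no checkpoint at {checkpoint_path}, starting fresh")
    return restored
```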

pytorch_lightning/trainer/trainer.py (+3 -3)

@@ -252,9 +252,9 @@ def __init__(
                 train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
                 you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
 
-            resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
-                This can be a URL. If resuming from mid-epoch checkpoint, training will start from
-                the beginning of the next epoch.
+            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
+                no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
+                training will start from the beginning of the next epoch.
 
             sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
 
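A minimal usage sketch of the documented behaviour; the path and model here are placeholders, not part of the diff. On the first run the checkpoint does not exist yet, so training starts from scratch with a warning; after an auto-resubmit the same path points at a real file and training resumes from it.

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=10,
    resume_from_checkpoint="some/dir/last.ckpt",  # may or may not exist yet
)
# trainer.fit(model)  # `model` is any LightningModule
```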

requirements.txt (+1 -1)

@@ -6,5 +6,5 @@ future>=0.17.1  # required for builtins in setup.py
 # pyyaml>=3.13
 PyYAML>=5.1  # OmegaConf requirement >=5.1
 tqdm>=4.41.0
-fsspec>=0.8.0
+fsspec[http]>=0.8.1
 tensorboard>=2.2.0

requirements/docs.txt (+1 -1)

@@ -11,4 +11,4 @@ https://github.com/PyTorchLightning/lightning_sphinx_theme/archive/master.zip#eg
 sphinx-autodoc-typehints
 sphinx-paramlinks<0.4.0
 sphinx-togglebutton
-sphinx-copybutton
+sphinx-copybutton

tests/models/test_restore.py (+20 -1)

@@ -25,7 +25,7 @@
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from tests.base import EvalModelTemplate, GenericEvalModelTemplate
+from tests.base import BoringModel, EvalModelTemplate, GenericEvalModelTemplate
 
 
 class ModelTrainerPropertyParity(Callback):
@@ -71,6 +71,25 @@ def test_model_properties_resume_from_checkpoint(enable_pl_optimizer, tmpdir):
     trainer.fit(model)
 
 
+def test_try_resume_from_non_existing_checkpoint(tmpdir):
+    """Test that trying to resume from a non-existing `resume_from_checkpoint` fails without error."""
+    model = BoringModel()
+    checkpoint_cb = ModelCheckpoint(dirpath=tmpdir, monitor="early_stop_on", save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        logger=False,
+        callbacks=[checkpoint_cb],
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+    )
+    # Generate checkpoint `last.ckpt` with BoringModel
+    trainer.fit(model)
+    # `restore` returns `True` if the restore succeeds, else `False`
+    assert trainer.checkpoint_connector.restore(str(tmpdir / "last.ckpt"), trainer.on_gpu)
+    assert not trainer.checkpoint_connector.restore(str(tmpdir / "last_non_existing.ckpt"), trainer.on_gpu)
+
+
 class CaptureCallbacksBeforeTraining(Callback):
     callbacks = []
 
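To try the new test in isolation, one option (assuming a source checkout with the test dependencies installed) is to invoke pytest on the test's node id from Python:

```python
import pytest

# Runs only the newly added test; equivalent to the same node id on the pytest CLI.
pytest.main(["-q", "tests/models/test_restore.py::test_try_resume_from_non_existing_checkpoint"])
```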
