
Commit 41e71af

yukw777 authored and Borda committed
Allow loading checkpoints from urls (#1667)
* allow loading checkpoints from urls
* tmpdir_server fixture
* test cases for loading checkpoints from url
* dir => root_dir
* default map_location to None
* test case for resume_from_checkpoint
* changelog
* doc update
* monkeypatch TORCH_HOME to avoid caching
* Use a threading server with random ports so that it is easier to clean up
* test fixes
* pep8 fix
* ThreadingHTTPServer support in 3.6
* pep8 fix
* fix changelog
* separate tests for urls
* typo

Co-authored-by: Peter Yu <[email protected]>

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 06cd849)
1 parent 768579d commit 41e71af

File tree

7 files changed: +81 -13 lines changed

- CHANGELOG.md
- pytorch_lightning/core/saving.py
- pytorch_lightning/trainer/trainer.py
- pytorch_lightning/trainer/training_io.py
- pytorch_lightning/utilities/io.py
- tests/conftest.py
- tests/trainer/test_trainer.py

CHANGELOG.md

+1

@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
 - Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
 - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
+- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))

 ### Changed

pytorch_lightning/core/saving.py

+5 -5

@@ -10,6 +10,7 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
+from pytorch_lightning.utilities.io import load as pl_load

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
@@ -52,10 +53,10 @@ def load_from_checkpoint(
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to `__init__` in the checkpoint under `module_arguments`

-        Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`.
+        Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.

         Args:
-            checkpoint_path: Path to checkpoint.
+            checkpoint_path: Path to checkpoint. This can also be a URL.
             args: Any positional args needed to init the model.
             map_location:
                 If your checkpoint saved a GPU model and you now load on CPUs
@@ -131,9 +132,9 @@ def load_from_checkpoint(
             y_hat = pretrained_model(x)
         """
         if map_location is not None:
-            checkpoint = torch.load(checkpoint_path, map_location=map_location)
+            checkpoint = pl_load(checkpoint_path, map_location=map_location)
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

         # add the hparams from csv file to checkpoint
         if tags_csv is not None:
@@ -162,7 +163,6 @@ def load_from_checkpoint(

     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
-
         # pass in the values we saved automatically
         if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
             # todo add some back compatibility
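In user code, this means `load_from_checkpoint` now accepts an http(s) URL wherever a local path was accepted before. A minimal sketch, assuming a hypothetical `LitModel` subclass of `LightningModule` and a placeholder checkpoint URL (neither is part of this commit):

    # hypothetical model class and URL, for illustration only
    model = LitModel.load_from_checkpoint(
        'https://example.com/checkpoints/epoch=2.ckpt',
        map_location='cpu',  # optional explicit device mapping
    )
    model.eval()

    # plain local paths keep working exactly as before
    model = LitModel.load_from_checkpoint('/path/to/epoch=2.ckpt')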

pytorch_lightning/trainer/trainer.py

+1

@@ -279,6 +279,7 @@ def __init__(
             truncated_bptt_steps: Truncated back prop breaks performs backprop every k steps of

             resume_from_checkpoint: To resume training from a specific checkpoint pass in the path here.
+                This can be a URL.

             profiler: To profile individual steps during training and assist in

pytorch_lightning/trainer/training_io.py

+2 -1

@@ -102,6 +102,7 @@
     LightningDataParallel,
 )
 from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities.io import load as pl_load

 try:
     import torch_xla
@@ -287,7 +288,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
         # checkpoint = torch.load(checkpoint_path)
         # else:
         # load on CPU first
-        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

         # load model state
         model = self.get_model()
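Because `restore()` now goes through `pl_load`, the `resume_from_checkpoint` Trainer argument documented above can point at a URL as well as a local file. A minimal sketch, assuming a hypothetical checkpoint URL and an already constructed `model`:

    from pytorch_lightning import Trainer

    # hypothetical URL of a previously saved Lightning checkpoint
    trainer = Trainer(
        max_epochs=2,
        resume_from_checkpoint='https://example.com/checkpoints/epoch=1.ckpt',
    )
    trainer.fit(model)  # training state is restored from the downloaded checkpoint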

pytorch_lightning/utilities/io.py

+11

@@ -0,0 +1,11 @@
+import torch
+
+from urllib.parse import urlparse
+
+
+def load(path_or_url: str, map_location=None):
+    parsed = urlparse(path_or_url)
+    if parsed.scheme == '':
+        # local file
+        return torch.load(path_or_url, map_location=map_location)
+    return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
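The helper dispatches on the URL scheme: a plain path (no scheme) is read with `torch.load`, anything else is fetched with `torch.hub.load_state_dict_from_url`, which caches downloads in the torch hub cache controlled by `$TORCH_HOME` (hence the monkeypatching in the tests below). A small sketch of calling it directly, with placeholder paths and URL:

    from pytorch_lightning.utilities.io import load as pl_load

    # local file: handled by torch.load
    state = pl_load('/tmp/checkpoints/epoch=2.ckpt', map_location='cpu')

    # http(s) URL: downloaded (and cached) by torch.hub, then loaded
    state = pl_load('https://example.com/checkpoints/epoch=2.ckpt', map_location='cpu')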

tests/conftest.py

+39 -1

@@ -1,6 +1,9 @@
-from functools import wraps
+from functools import wraps, partial
+from http.server import SimpleHTTPRequestHandler

+import sys
 import pytest
+import threading
 import torch.multiprocessing as mp


@@ -17,3 +20,38 @@ def pytest_pyfunc_call(pyfuncitem):

         mp.spawn(wraps, (testfunction, testargs))
         return True
+
+
+@pytest.fixture
+def tmpdir_server(tmpdir):
+    if sys.version_info >= (3, 7):
+        Handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir))
+        from http.server import ThreadingHTTPServer
+    else:
+        # unfortunately SimpleHTTPRequestHandler doesn't accept the directory arg in python3.6
+        # so we have to hack it like this
+        import os
+
+        class Handler(SimpleHTTPRequestHandler):
+            def translate_path(self, path):
+                # get the path from cwd
+                path = super().translate_path(path)
+                # get the relative path
+                relpath = os.path.relpath(path, os.getcwd())
+                # return the full path from root_dir
+                return os.path.join(str(tmpdir), relpath)
+
+        # ThreadingHTTPServer was added in 3.7, so we need to define it ourselves
+        from socketserver import ThreadingMixIn
+        from http.server import HTTPServer
+
+        class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
+            daemon_threads = True
+
+    with ThreadingHTTPServer(('', 0), Handler) as server:
+        server_thread = threading.Thread(target=server.serve_forever)
+        # Exit the server thread when the main thread terminates
+        server_thread.daemon = True
+        server_thread.start()
+        yield server.server_address
+        server.shutdown()
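The fixture binds to port 0 so the OS assigns a free port, serves `tmpdir` from a daemon thread, and yields `server.server_address` as an `(ip, port)` pair, which is why cleanup only needs `server.shutdown()` after the test finishes. A sketch of a hypothetical test using it to fetch a file it just wrote (the file name and contents are placeholders):

    import urllib.request

    def test_tmpdir_server_serves_files(tmpdir, tmpdir_server):
        # write a file into the directory the fixture is serving
        tmpdir.join('weights.ckpt').write_binary(b'dummy checkpoint bytes')

        ip, port = tmpdir_server
        url = f'http://{ip}:{port}/weights.ckpt'
        assert urllib.request.urlopen(url).read() == b'dummy checkpoint bytes'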

tests/trainer/test_trainer.py

+22 -6

@@ -16,12 +16,16 @@
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
+from pytorch_lightning.utilities.io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate


-def test_no_val_module(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     model = EvalModelTemplate()

@@ -49,15 +53,19 @@ def test_no_val_module(tmpdir):
     # load new model
     hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
     hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
     model_2 = EvalModelTemplate.load_from_checkpoint(
-        checkpoint_path=new_weights_path,
+        checkpoint_path=ckpt_path,
         hparams_file=hparams_path
     )
     model_2.eval()


-def test_no_val_end_module(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Tests use case where trainer saves the model, and user loads it from tags independently."""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     model = EvalModelTemplate()

@@ -82,8 +90,9 @@ def test_no_val_end_module(tmpdir):
     # load new model
     hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
     hparams_path = os.path.join(hparams_path, 'hparams.yaml')
+    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' if url_ckpt else new_weights_path
     model_2 = EvalModelTemplate.load_from_checkpoint(
-        checkpoint_path=new_weights_path,
+        checkpoint_path=ckpt_path,
         hparams_file=hparams_path
     )
     model_2.eval()
@@ -320,8 +329,11 @@ def test_model_freeze_unfreeze():
     model.unfreeze()


-def test_resume_from_checkpoint_epoch_restored(tmpdir):
+@pytest.mark.parametrize('url_ckpt', [True, False])
+def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     """Verify resuming from checkpoint runs the right number of epochs"""
+    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
+    monkeypatch.setenv('TORCH_HOME', tmpdir)

     hparams = EvalModelTemplate.get_default_hparams()

@@ -373,10 +385,14 @@ def increment_on_load_checkpoint(self, _):

     # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
     checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, '*.ckpt')))
+    if url_ckpt:
+        # transform local paths into url checkpoints
+        ip, port = tmpdir_server
+        checkpoints = [f'http://{ip}:{port}/' + os.path.basename(check) for check in checkpoints]

     for check in checkpoints:
         next_model = _new_model()
-        state = torch.load(check)
+        state = pl_load(check)

         # Resume training
         trainer_options['max_epochs'] = 2
