Skip to content

Commit d0e04dc

Browse files
shijianjianBordajustusschockrohitgr7awaelchli
authored and
atee
committed
Added strict=False for load_from_checkpoint (Lightning-AI#2819)
* Added strict=False and hparams_file accepts dict * Apply suggestions from code review Co-authored-by: Justus Schock <[email protected]> * Type check fix * Added tests * Linting & test fix * Removed redundant code & test * Added strict=False and hparams_file accepts dict * Apply suggestions from code review Co-authored-by: Justus Schock <[email protected]> * Type check fix * Added tests * Linting & test fix * Removed redundant code & test * Apply suggestions from code review * tests * tests * chlog * Update tests/models/test_restore.py Co-authored-by: Rohit Gupta <[email protected]> * update test comments * Added docstring for the strict attribute * Added supplementary tests * Update saving.py * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * pep8, removed extra func Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Justus Schock <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: ananyahjha93 <[email protected]>
1 parent d259394 commit d0e04dc

File tree

4 files changed

+168
-3
lines changed

4 files changed

+168
-3
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4848

4949
- Added GPU Usage Logger ([#2932](https://github.com/PyTorchLightning/pytorch-lightning/pull/2932))
5050

51+
- Added `strict=False` and `hparams_file` accepts dict for `load_from_checkpoint` ([#2819](https://github.com/PyTorchLightning/pytorch-lightning/pull/2819))
52+
5153
### Changed
5254

5355
- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

pytorch_lightning/core/saving.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def load_from_checkpoint(
3939
*args,
4040
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
4141
hparams_file: Optional[str] = None,
42+
strict: bool = True,
4243
**kwargs
4344
):
4445
r"""
@@ -71,6 +72,8 @@ def load_from_checkpoint(
7172
If your model's `hparams` argument is :class:`~argparse.Namespace`
7273
and .yaml file has hierarchical structure, you need to refactor your model to treat
7374
`hparams` as :class:`~dict`.
75+
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
76+
returned by this module's state dict. Default: `True`.
7477
hparam_overrides: A dictionary with keys to override in the hparams
7578
kwargs: Any keyword args needed to init the model.
7679
@@ -133,11 +136,11 @@ def load_from_checkpoint(
133136
# override the hparams with values that were passed in
134137
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
135138

136-
model = cls._load_model_state(checkpoint, *args, **kwargs)
139+
model = cls._load_model_state(checkpoint, strict=strict, *args, **kwargs)
137140
return model
138141

139142
@classmethod
140-
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
143+
def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, *cls_args, **cls_kwargs):
141144
cls_spec = inspect.getfullargspec(cls.__init__)
142145
cls_init_args_name = inspect.signature(cls).parameters.keys()
143146
# pass in the values we saved automatically
@@ -172,7 +175,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
172175

173176
model = cls(*cls_args, **cls_kwargs)
174177
# load the state_dict on the model automatically
175-
model.load_state_dict(checkpoint['state_dict'])
178+
model.load_state_dict(checkpoint['state_dict'], strict=strict)
176179

177180
# give model a chance to load something
178181
model.on_load_checkpoint(checkpoint)

tests/models/test_restore.py

+100
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging as log
33
import os
44
import pickle
5+
import functools
56

67
import cloudpickle
78
import pytest
@@ -319,6 +320,105 @@ def test_model_saving_loading(tmpdir):
319320
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1
320321

321322

323+
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """A checkpoint holding extra parameters loads only when strict=False."""
    # torch hub caches under $TORCH_HOME; keep it inside the test dir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()
    # extra layer that the plain template class does not define
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    assert trainer.fit(model) == 1  # training completed

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # resolve checkpoint location (served over http or read from disk)
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    if url_ckpt:
        ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    else:
        ckpt_path = new_weights_path

    # non-strict loading tolerates the unexpected keys from the extra layer
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # strict loading must reject them
    with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )
369+
370+
371+
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """A checkpoint missing a layer's weights loads into a larger model only when strict=False."""
    # torch hub caches under $TORCH_HOME; keep it inside the test dir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    assert trainer.fit(model) == 1  # training completed

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # resolve checkpoint location (served over http or read from disk)
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    if url_ckpt:
        ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    else:
        ckpt_path = new_weights_path

    # subclass with an extra layer whose weights the checkpoint lacks
    class CurrentModel(EvalModelTemplate):
        def __init__(self):
            super().__init__()
            self.c_d3 = torch.nn.Linear(7, 7)

    # non-strict loading tolerates the missing keys
    CurrentModel.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # strict loading must reject the incomplete state dict
    with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        CurrentModel.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )
420+
421+
322422
def test_model_pickle(tmpdir):
323423
model = EvalModelTemplate()
324424
pickle.dumps(model)

tests/trainer/test_trainer.py

+60
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,66 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
104104
model_2.eval()
105105

106106

107+
@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Loading a checkpoint with extra keys fails by default and succeeds with strict=False."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()
    # Extra layer — its weights end up in the checkpoint but not in the
    # plain EvalModelTemplate used for loading below
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' \
        if url_ckpt else new_weights_path

    # default (strict=True) must reject the checkpoint's extra layer.
    # pytest.raises replaces the broad try/except-Exception flag pattern and
    # pins the specific RuntimeError raised by load_state_dict.
    with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict'):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
        )

    # strict=False tolerates the extra keys; any exception fails the test
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )
165+
166+
107167
@pytest.mark.parametrize(
108168
['schedule', 'expected'],
109169
[

0 commit comments

Comments
 (0)