Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added strict=False for load_from_checkpoint #2819

Merged
merged 23 commits into from
Aug 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added GPU Usage Logger ([#2932](https://github.com/PyTorchLightning/pytorch-lightning/pull/2932))

- Added `strict=False` and made `hparams_file` accept a dict in `load_from_checkpoint` ([#2819](https://github.com/PyTorchLightning/pytorch-lightning/pull/2819))

### Changed

- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
Expand Down
9 changes: 6 additions & 3 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def load_from_checkpoint(
*args,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs
):
r"""
Expand Down Expand Up @@ -71,6 +72,8 @@ def load_from_checkpoint(
If your model's `hparams` argument is :class:`~argparse.Namespace`
and .yaml file has hierarchical structure, you need to refactor your model to treat
`hparams` as :class:`~dict`.
strict: Whether to strictly enforce that the keys in :attr:`checkpoint_path` match the keys
returned by this module's state dict. Default: `True`.
hparam_overrides: A dictionary with keys to override in the hparams
kwargs: Any keyword args needed to init the model.

Expand Down Expand Up @@ -133,11 +136,11 @@ def load_from_checkpoint(
# override the hparams with values that were passed in
checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)

model = cls._load_model_state(checkpoint, *args, **kwargs)
model = cls._load_model_state(checkpoint, strict=strict, *args, **kwargs)
return model

@classmethod
def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, *cls_args, **cls_kwargs):
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls).parameters.keys()
# pass in the values we saved automatically
Expand Down Expand Up @@ -172,7 +175,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):

model = cls(*cls_args, **cls_kwargs)
# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])
model.load_state_dict(checkpoint['state_dict'], strict=strict)

# give model a chance to load something
model.on_load_checkpoint(checkpoint)
Expand Down
100 changes: 100 additions & 0 deletions tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging as log
import os
import pickle
import functools

import cloudpickle
import pytest
Expand Down Expand Up @@ -319,6 +320,105 @@ def test_model_saving_loading(tmpdir):
assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1


@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Test loading a checkpoint that has MORE parameters than the target model.

    The checkpoint is saved from a model with an extra layer (``c_d3``).
    Loading it into a plain ``EvalModelTemplate`` must raise with the default
    ``strict=True`` and succeed with ``strict=False``.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()
    # extra layer that the vanilla EvalModelTemplate does not define
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model, optionally over HTTP from the test file server
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    # NOTE: this URL points at the checkpoint file served by tmpdir_server
    ckpt_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    # non-strict load tolerates the unexpected keys
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # strict load must surface the extra keys present in the checkpoint
    with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )


@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Test loading a checkpoint that has FEWER parameters than the target model.

    The checkpoint is saved from a plain ``EvalModelTemplate``; the target
    subclass adds an extra layer (``c_d3``). Loading must raise with
    ``strict=True`` (missing keys) and succeed with ``strict=False``.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model, optionally over HTTP from the test file server
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), 'hparams.yaml')
    # NOTE: this URL points at the checkpoint file served by tmpdir_server
    ckpt_url = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}'
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    class CurrentModel(EvalModelTemplate):
        """EvalModelTemplate with one extra layer the checkpoint lacks."""

        def __init__(self):
            super().__init__()
            self.c_d3 = torch.nn.Linear(7, 7)

    # non-strict load tolerates the keys missing from the checkpoint
    CurrentModel.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )

    # strict load must surface the keys the checkpoint does not provide
    with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        CurrentModel.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
            strict=True,
        )


def test_model_pickle(tmpdir):
    """Smoke test that the LightningModule template can be pickled."""
    model = EvalModelTemplate()
    pickle.dumps(model)
Expand Down
60 changes: 60 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,66 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
model_2.eval()


@pytest.mark.parametrize('url_ckpt', [True, False])
def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Test that loading a checkpoint with unexpected keys fails by default
    (``strict=True``) and succeeds when ``strict=False`` is passed.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv('TORCH_HOME', tmpdir)

    model = EvalModelTemplate()
    # extra layer that the checkpoint will contain but the vanilla class does not
    model.c_d3 = torch.nn.Linear(model.hidden_dim, model.hidden_dim)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        logger=logger,
        checkpoint_callback=ModelCheckpoint(tmpdir),
    )
    result = trainer.fit(model)

    # training complete
    assert result == 1

    # save model
    new_weights_path = os.path.join(tmpdir, 'save_test.ckpt')
    trainer.save_checkpoint(new_weights_path)

    # load new model, optionally over HTTP from the test file server
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    ckpt_path = f'http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}' \
        if url_ckpt else new_weights_path

    # the default strict load must reject the extra layer in the checkpoint
    with pytest.raises(RuntimeError):
        EvalModelTemplate.load_from_checkpoint(
            checkpoint_path=ckpt_path,
            hparams_file=hparams_path,
        )

    # strict=False tolerates the unexpected keys and loads cleanly
    EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=ckpt_path,
        hparams_file=hparams_path,
        strict=False,
    )


@pytest.mark.parametrize(
['schedule', 'expected'],
[
Expand Down