
Commit 2dca6ea

awaelchli and williamFalcon authored and atee committed
Fix hparams loading for model that accepts *args (Lightning-AI#2911)
* fix hparams loading for model that accepts *args
* add test case
* changelog
* pep
* fix test

Co-authored-by: William Falcon <[email protected]>
1 parent 188c6be commit 2dca6ea
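
Background on the bug: when restoring a checkpoint, `_load_model_state` inspects the target class's `__init__` and drops all positional arguments if the class appears to accept none. A subclass that only declares `*args, **kwargs` looks exactly like that to the old check, because `inspect.getfullargspec` lists no named parameters besides `self`, so the stored hparams never reached the parent constructor. A minimal sketch of what the argspec reports (the `Parent`/`Child` classes below are illustrative, not from the repository):

import inspect

class Parent:
    def __init__(self, hparams):
        self.hparams = hparams

class Child(Parent):
    # forwards everything to the parent through variable positional args
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

spec = inspect.getfullargspec(Child.__init__)
print(spec.args)        # ['self']  -> old check: "accepts no positional arguments"
print(spec.varargs)     # 'args'    -> the fix also checks this before dropping cls_args
print(spec.kwonlyargs)  # []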

3 files changed: +30 −1 lines changed

CHANGELOG.md (+2)

@@ -123,6 +123,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed save apex scaler states ([#2828](https://github.com/PyTorchLightning/pytorch-lightning/pull/2828))
 
+- Fixed a model loading issue with inheritance and variable positional arguments ([#2911](https://github.com/PyTorchLightning/pytorch-lightning/pull/2911))
+
 - Fixed passing `non_blocking=True` when transferring a batch object that does not support it ([#2910](https://github.com/PyTorchLightning/pytorch-lightning/pull/2910))
 
 ## [0.8.5] - 2020-07-09

pytorch_lightning/core/saving.py (+2 −1)

@@ -167,8 +167,9 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
         cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}
 
         # prevent passing positional arguments if class does not accept any
-        if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
+        if len(cls_spec.args) <= 1 and not cls_spec.varargs and not cls_spec.kwonlyargs:
             cls_args, cls_kwargs = [], {}
+
         model = cls(*cls_args, **cls_kwargs)
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
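
The change only touches the guard above: a class whose `__init__` takes `*args` (and therefore reports a `varargs` entry in its argspec) should keep the positional arguments restored from the checkpoint. Below is a standalone sketch of the patched condition; `should_drop_positional_args` and the two model classes are illustrative names, not part of the Lightning codebase:

import inspect

def should_drop_positional_args(cls):
    # simplified mirror of the guard in _load_model_state above
    spec = inspect.getfullargspec(cls.__init__)
    return len(spec.args) <= 1 and not spec.varargs and not spec.kwonlyargs

class NoArgsModel:
    def __init__(self):
        pass

class VarArgsModel:
    def __init__(self, *args, **kwargs):
        super().__init__()

assert should_drop_positional_args(NoArgsModel)        # nothing to pass, safe to call cls()
assert not should_drop_positional_args(VarArgsModel)   # *args may forward hparams, so keep them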

tests/models/test_hparams.py (+26)

@@ -497,3 +497,29 @@ def __init__(self, batch_size=15):
     raw_checkpoint_path = _raw_checkpoint_path(trainer)
     model = LocalModel.load_from_checkpoint(raw_checkpoint_path, non_exist_kwarg=99)
     assert 'non_exist_kwarg' not in model.hparams
+
+
+class SuperClassPositionalArgs(EvalModelTemplate):
+
+    def __init__(self, hparams):
+        super().__init__()
+        self._hparams = None  # pretend EvalModelTemplate did not call self.save_hyperparameters()
+        self.hparams = hparams
+
+
+class SubClassVarArgs(SuperClassPositionalArgs):
+    """ Loading this model should accept hparams and init in the super class """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+
+def test_args(tmpdir):
+    """ Test for inheritance: super class takes positional arg, subclass takes varargs. """
+    hparams = dict(test=1)
+    model = SubClassVarArgs(hparams)
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    trainer.fit(model)
+
+    raw_checkpoint_path = _raw_checkpoint_path(trainer)
+    model = SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)
+    assert model.hparams == hparams
