
Commit 2019081

jbschiratti authored and williamFalcon committed
Fix? _load_model_state cleanup
1 parent 30b171b commit 2019081

File tree

2 files changed: +32 −21 lines


pytorch_lightning/core/saving.py

+23 −17

@@ -136,37 +136,43 @@ def load_from_checkpoint(
         return model
 
     @classmethod
-    def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args, **cls_kwargs):
+    def _load_model_state(cls, checkpoint: Dict[str, Any], *cls_args_new, **cls_kwargs_new):
         cls_spec = inspect.getfullargspec(cls.__init__)
         cls_init_args_name = inspect.signature(cls).parameters.keys()
         # pass in the values we saved automatically
         if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-            model_args = {}
+            cls_kwargs_old = {}
 
-            # add some back compatibility, the actual one shall be last
-            for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
-                if hparam_key in checkpoint:
-                    model_args.update(checkpoint[hparam_key])
+            # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
+            for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS:
+                if _old_hparam_key in checkpoint:
+                    cls_kwargs_old.update({_old_hparam_key: checkpoint[_old_hparam_key]})
 
-            model_args = _convert_loaded_hparams(model_args, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
+            # 2. Try to restore model hparams from checkpoint using the new key
+            _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
+            cls_kwargs_old.update({_new_hparam_key: checkpoint[_new_hparam_key]})
 
-            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
+            # 3. Ensure that `cls_kwargs_old` has the right type
+            cls_kwargs_old = _convert_loaded_hparams(cls_kwargs_old, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
 
-            if args_name == 'kwargs':
-                # in case the class cannot take any extra argument filter only the possible
-                cls_kwargs.update(**model_args)
-            elif args_name:
-                if args_name in cls_init_args_name:
-                    cls_kwargs.update({args_name: model_args})
+            # 4. Update cls_kwargs_new with cls_kwargs_old
+            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
+            if args_name and args_name in cls_init_args_name:
+                cls_kwargs_new.update({args_name: cls_kwargs_old})
+            else:
+                cls_kwargs_new.update(cls_kwargs_old)
 
         if not cls_spec.varkw:
             # filter kwargs according to class init unless it allows any argument via kwargs
-            cls_kwargs = {k: v for k, v in cls_kwargs.items() if k in cls_init_args_name}
+            cls_kwargs_new = {k: v for k, v in cls_kwargs_new.items() if k in cls_init_args_name}
 
         # prevent passing positional arguments if class does not accept any
         if len(cls_spec.args) <= 1 and not cls_spec.kwonlyargs:
-            cls_args, cls_kwargs = [], {}
-        model = cls(*cls_args, **cls_kwargs)
+            _cls_args_new, _cls_kwargs_new = [], {}
+        else:
+            _cls_args_new, _cls_kwargs_new = cls_args_new, cls_kwargs_new
+
+        model = cls(*cls_args_new, **cls_kwargs_new)
         # load the state_dict on the model automatically
         model.load_state_dict(checkpoint['state_dict'])
 
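For orientation, the sketch below restates steps 1–4 outside the class: pull any saved hyperparameters out of a checkpoint dictionary (trying an assumed legacy key before the current one), merge them into the keyword arguments that will be passed to the constructor, and keep only the names the `__init__` signature accepts. The key names, the helper name, and `ToyModel` are illustrative stand-ins rather than the exact constants or code path in pytorch_lightning; as in the diff, restored values take precedence over the passed-in kwargs.

import inspect
from typing import Any, Dict

# Illustrative stand-ins; the real constants live on LightningModule.
PAST_HPARAMS_KEYS = ('hparams',)        # assumed legacy checkpoint key(s)
HYPER_PARAMS_KEY = 'hyper_parameters'   # assumed current checkpoint key


def merge_checkpoint_hparams(cls, checkpoint: Dict[str, Any], **cls_kwargs_new) -> Dict[str, Any]:
    """Sketch: merge hparams saved in `checkpoint` into the kwargs for `cls(...)`."""
    init_arg_names = inspect.signature(cls).parameters.keys()

    saved_hparams: Dict[str, Any] = {}
    # 1. backward compatibility: look under legacy keys first ...
    for key in PAST_HPARAMS_KEYS:
        if key in checkpoint:
            saved_hparams.update(checkpoint[key])
    # 2. ... then let the current key take precedence
    if HYPER_PARAMS_KEY in checkpoint:
        saved_hparams.update(checkpoint[HYPER_PARAMS_KEY])

    # 3. (the type-conversion step is omitted in this sketch)
    # 4. restored hparams override passed-in kwargs, then filter to accepted init args
    cls_kwargs_new.update(saved_hparams)
    return {k: v for k, v in cls_kwargs_new.items() if k in init_arg_names}


# Tiny usage example with a hypothetical model:
class ToyModel:
    def __init__(self, hidden_dim: int = 16, lr: float = 1e-3):
        self.hidden_dim, self.lr = hidden_dim, lr

ckpt = {'hyper_parameters': {'hidden_dim': 64, 'lr': 1e-3}, 'state_dict': {}}
model = ToyModel(**merge_checkpoint_hparams(ToyModel, ckpt))
assert model.hidden_dim == 64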

tests/models/test_restore.py

+9 −4

@@ -152,6 +152,8 @@ def test_running_test_pretrained_model_cpu(tmpdir):
 
 def test_load_model_from_checkpoint(tmpdir):
     """Verify test() on pretrained model."""
+    from pytorch_lightning import LightningModule
+
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(**hparams)
 
@@ -174,20 +176,23 @@ def test_load_model_from_checkpoint(tmpdir):
 
     # load last checkpoint
     last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]
+
+    # Since `EvalModelTemplate` has `_save_hparams = True` by default, check that ckpt has hparams
+    ckpt = torch.load(last_checkpoint)
+    assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), 'module_arguments missing from checkpoints'
+
+    # Ensure that model can be correctly restored from checkpoint
     pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint)
 
-    # test that hparams loaded correctly
     for k, v in hparams.items():
         assert getattr(pretrained_model, k) == v
 
-    # assert weights are the same
     for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()):
         assert torch.all(torch.eq(old_p, new_p)), 'loaded weights are not the same as the saved weights'
 
+    # Check `test` on pretrained model:
     new_trainer = Trainer(**trainer_options)
     new_trainer.test(pretrained_model)
-
-    # test we have good test accuracy
     tutils.assert_ok_model_acc(new_trainer)
 
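The new assertions in this test can also be reproduced by hand on any checkpoint written by the Trainer; a minimal sketch, assuming a placeholder checkpoint path and a hypothetical user-defined module class (`MyLightningModule`):

import torch
from pytorch_lightning import LightningModule

ckpt_path = 'example.ckpt'  # placeholder: path to a checkpoint saved by Trainer

# The saved hyperparameters should sit under the key exposed on LightningModule.
ckpt = torch.load(ckpt_path, map_location='cpu')
assert LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in ckpt, 'hyperparameters missing from checkpoint'
print(ckpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY])

# Restoring then goes through the classmethod touched above, e.g.:
# model = MyLightningModule.load_from_checkpoint(ckpt_path)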
