@@ -136,37 +136,43 @@ def load_from_checkpoint(
136
136
return model
137
137
138
138
@classmethod
139
- def _load_model_state (cls , checkpoint : Dict [str , Any ], * cls_args , ** cls_kwargs ):
139
+ def _load_model_state (cls , checkpoint : Dict [str , Any ], * cls_args_new , ** cls_kwargs_new ):
140
140
cls_spec = inspect .getfullargspec (cls .__init__ )
141
141
cls_init_args_name = inspect .signature (cls ).parameters .keys ()
142
142
# pass in the values we saved automatically
143
143
if cls .CHECKPOINT_HYPER_PARAMS_KEY in checkpoint :
144
- model_args = {}
144
+ cls_kwargs_old = {}
145
145
146
- # add some back compatibility, the actual one shall be last
147
- for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + ( cls . CHECKPOINT_HYPER_PARAMS_KEY ,) :
148
- if hparam_key in checkpoint :
149
- model_args .update (checkpoint [hparam_key ] )
146
+ # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys
147
+ for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS :
148
+ if _old_hparam_key in checkpoint :
149
+ cls_kwargs_old .update ({ _old_hparam_key : checkpoint [_old_hparam_key ]} )
150
150
151
- model_args = _convert_loaded_hparams (model_args , checkpoint .get (cls .CHECKPOINT_HYPER_PARAMS_TYPE ))
151
+ # 2. Try to restore model hparams from checkpoint using the new key
152
+ _new_hparam_key = cls .CHECKPOINT_HYPER_PARAMS_KEY
153
+ cls_kwargs_old .update ({_new_hparam_key : checkpoint [_new_hparam_key ]})
152
154
153
- args_name = checkpoint .get (cls .CHECKPOINT_HYPER_PARAMS_NAME )
155
+ # 3. Ensure that `cls_kwargs_old` has the right type
156
+ cls_kwargs_old = _convert_loaded_hparams (cls_kwargs_old , checkpoint .get (cls .CHECKPOINT_HYPER_PARAMS_TYPE ))
154
157
155
- if args_name == 'kwargs' :
156
- # in case the class cannot take any extra argument filter only the possible
157
- cls_kwargs . update ( ** model_args )
158
- elif args_name :
159
- if args_name in cls_init_args_name :
160
- cls_kwargs .update ({ args_name : model_args } )
158
+ # 4. Update cls_kwargs_new with cls_kwargs_old
159
+ args_name = checkpoint . get ( cls . CHECKPOINT_HYPER_PARAMS_NAME )
160
+ if args_name and args_name in cls_init_args_name :
161
+ cls_kwargs_new . update ({ args_name : cls_kwargs_old })
162
+ else :
163
+ cls_kwargs_new .update (cls_kwargs_old )
161
164
162
165
if not cls_spec .varkw :
163
166
# filter kwargs according to class init unless it allows any argument via kwargs
164
- cls_kwargs = {k : v for k , v in cls_kwargs .items () if k in cls_init_args_name }
167
+ cls_kwargs_new = {k : v for k , v in cls_kwargs_new .items () if k in cls_init_args_name }
165
168
166
169
# prevent passing positional arguments if class does not accept any
167
170
if len (cls_spec .args ) <= 1 and not cls_spec .kwonlyargs :
168
- cls_args , cls_kwargs = [], {}
169
- model = cls (* cls_args , ** cls_kwargs )
171
+ _cls_args_new , _cls_kwargs_new = [], {}
172
+ else :
173
+ _cls_args_new , _cls_kwargs_new = cls_args_new , cls_kwargs_new
174
+
175
+ model = cls (* cls_args_new , ** cls_kwargs_new )
170
176
# load the state_dict on the model automatically
171
177
model .load_state_dict (checkpoint ['state_dict' ])
172
178
0 commit comments