File tree: 1 file changed, +3 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -177,15 +177,16 @@ def get_config_from_statedict(state_dict,
177
177
pad_token_id = 0 ,
178
178
layer_norm_eps = 1e-12 ,
179
179
dropout_rate = 0.1 ):
180
- regex = re .compile (r'encoder.encoder.layer.\d+.feed_forward.weight' )
181
- num_layers = len ([key for key in state_dict .keys () if regex .search (key )])
182
180
is_pretraining_checkpoint = 'mlm_output.weight' in state_dict .keys ()
183
181
184
182
def prepare (key ):
185
183
if is_pretraining_checkpoint :
186
184
return f"encoder.{ key } "
187
185
return key
188
186
187
+ regex = re .compile (prepare (r'encoder.layer.\d+.feed_forward.weight' ))
188
+ num_layers = len ([key for key in state_dict .keys () if regex .search (key )])
189
+
189
190
return {
190
191
"num_hidden_layers" : num_layers ,
191
192
"vocab_size" : state_dict [prepare ('embeddings.word_embeddings.weight' )].shape [0 ],
You can’t perform that action at this time.
0 commit comments