Skip to content

Commit 3c3c644

Browse files
authored
Fix regex for non-pretraining checkpoints config retrieval
1 parent 843f1e2 commit 3c3c644

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

Diff for: fnet.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -177,15 +177,16 @@ def get_config_from_statedict(state_dict,
177177
pad_token_id=0,
178178
layer_norm_eps=1e-12,
179179
dropout_rate=0.1):
180-
regex = re.compile(r'encoder.encoder.layer.\d+.feed_forward.weight')
181-
num_layers = len([key for key in state_dict.keys() if regex.search(key)])
182180
is_pretraining_checkpoint = 'mlm_output.weight' in state_dict.keys()
183181

184182
def prepare(key):
185183
if is_pretraining_checkpoint:
186184
return f"encoder.{key}"
187185
return key
188186

187+
regex = re.compile(prepare(r'encoder.layer.\d+.feed_forward.weight'))
188+
num_layers = len([key for key in state_dict.keys() if regex.search(key)])
189+
189190
return {
190191
"num_hidden_layers": num_layers,
191192
"vocab_size": state_dict[prepare('embeddings.word_embeddings.weight')].shape[0],

0 commit comments

Comments
 (0)