forked from zcdliuwei/RetinaNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathweight_npy_transfor_h5.py
42 lines (36 loc) · 1010 Bytes
/
weight_npy_transfor_h5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
# @Time : 2020/4/16 21:27
# @Author : Suke0
# @Email : [email protected]
# @File : weight_npy_transfor_h5.py
# @Software: PyCharm
import numpy as np
def transfor_weight(model, *, npy_path='./weight/weight.npy',
                    h5_path='./weight/resnet50retinanet_coco.h5',
                    verbose=True):
    """Copy positionally-ordered weights from a ``.npy`` file into ``model`` and save them.

    The ``.npy`` file is loaded with ``allow_pickle=True`` and iterated as a
    sequence of arrays. Model variables are matched to those arrays purely by
    position. First come all variables whose short name (the text after the
    last ``'/'``) contains ``'kernel'``, in ``model.variables`` order. Then come
    all remaining variables, also in ``model.variables`` order. Each variable
    is updated with ``assign`` and the model is written out with
    ``model.save_weights(h5_path)``.

    Args:
        model: Object exposing ``variables`` (each with ``name``, ``shape`` and
            ``assign``) and ``save_weights``. Presumably a tf.keras model;
            TODO confirm against callers.
        npy_path: Source ``.npy`` file holding the weight arrays.
        h5_path: Destination path passed to ``model.save_weights``.
        verbose: If true, print the shape of every loaded array, then
            ``<short name>__<shape>`` for every variable in assignment order.

    Raises:
        ValueError: If the number of arrays in ``npy_path`` differs from the
            number of model variables. A positional transfer with a count
            mismatch would otherwise silently misalign or drop weights.
    """
    # NOTE(review): allow_pickle=True lets a crafted .npy file execute
    # arbitrary code on load; only use files you produced or trust.
    weights = np.load(npy_path, allow_pickle=True)
    if verbose:
        for w in weights:
            print(w.shape)

    def short_name(var):
        # 'scope/layer/kernel:0' -> 'kernel:0'
        return var.name.split('/')[-1]

    variables = model.variables
    # Stable partition: every kernel first, then every other variable, each
    # group keeping model order. This must match the order in which the .npy
    # file was written, because matching is purely positional.
    ordered = ([v for v in variables if 'kernel' in short_name(v)]
               + [v for v in variables if 'kernel' not in short_name(v)])

    if verbose:
        for v in ordered:
            print(short_name(v) + '__' + str(v.shape))

    if len(ordered) != len(weights):
        raise ValueError(
            'weight count mismatch: %d arrays in %r but model has %d variables'
            % (len(weights), npy_path, len(ordered)))

    for var, value in zip(ordered, weights):
        var.assign(value)
    model.save_weights(h5_path)