40
40
from detectron2 .modeling import ProposalNetwork
41
41
42
42
from adet .config import get_cfg
43
- from adet .modeling import FCOS , BlendMask
43
+ from adet .modeling import FCOS , BlendMask , BAText , MEInst , condinst , SOLOv2
44
+ from adet .modeling .condinst .mask_branch import MaskBranch
45
+
46
+ def patch_condinst (cfg , model , output_names ):
47
+ def forward (self , tensor ):
48
+ images = None
49
+ gt_instances = None
50
+ mask_feats = None
51
+ proposals = None
52
+
53
+ features = self .backbone (tensor )
54
+ #return features
55
+ proposals , proposal_losses = self .proposal_generator (images , features , gt_instances , self .controller )
56
+ #return proposals
57
+ mask_feats , sem_losses = self .mask_branch (features , gt_instances )
58
+ #return mask_feats
59
+ return mask_feats , proposals
60
+
61
+ model .forward = types .MethodType (forward , model )
62
+
63
+ #output tensor naming [optional]
44
64
45
65
def patch_blendmask (cfg , model , output_names ):
46
66
def forward (self , tensor ):
@@ -54,7 +74,8 @@ def forward(self, tensor):
54
74
return basis_out ["bases" ], proposals
55
75
56
76
model .forward = types .MethodType (forward , model )
57
- # output
77
+
78
+ #output tensor naming [optional]
58
79
output_names .extend (["bases" ])
59
80
for item in ["logits" , "bbox_reg" , "centerness" , "top_feats" ]:
60
81
for l in range (len (cfg .MODEL .FCOS .FPN_STRIDES )):
@@ -71,7 +92,8 @@ def forward(self, tensor):
71
92
return proposals
72
93
73
94
model .forward = types .MethodType (forward , model )
74
- # output
95
+
96
+ #output tensor naming [optional]
75
97
for item in ["logits" , "bbox_reg" , "centerness" ]:
76
98
for l in range (len (cfg .MODEL .FCOS .FPN_STRIDES )):
77
99
fpn_name = "P{}" .format (3 + l )
@@ -95,7 +117,7 @@ def patch_fcos_head(cfg, fcos_head):
95
117
"share" : (cfg .MODEL .FCOS .NUM_SHARE_CONVS ,
96
118
False )}
97
119
98
- # step 2. seperate module
120
+ # step 2. separate module
99
121
for l in range (fcos_head .num_levels ):
100
122
for head in head_configs :
101
123
tower = []
@@ -137,6 +159,50 @@ def fcos_head_forward(self, x, top_module=None, yield_bbox_towers=False):
137
159
138
160
fcos_head .forward = types .MethodType (fcos_head_forward , fcos_head )
139
161
162
+ def upsample (tensor , factor ): # aligned_bilinear in adet/utils/comm.py is not onnx-friendly
163
+ assert tensor .dim () == 4
164
+ assert factor >= 1
165
+ assert int (factor ) == factor
166
+
167
+ if factor == 1 :
168
+ return tensor
169
+
170
+ h , w = tensor .size ()[2 :]
171
+ oh = factor * h
172
+ ow = factor * w
173
+ tensor = F .interpolate (
174
+ tensor , size = (oh , ow ),
175
+ mode = 'nearest' ,
176
+ )
177
+ return tensor
178
+
179
+ def patch_mask_branch (cfg , mask_branch ):
180
+ def mask_branch_forward (self , features , gt_instances = None ):
181
+ for i , f in enumerate (self .in_features ):
182
+ if i == 0 :
183
+ x = self .refine [i ](features [f ])
184
+ else :
185
+ x_p = self .refine [i ](features [f ])
186
+
187
+ target_h , target_w = x .size ()[2 :]
188
+ h , w = x_p .size ()[2 :]
189
+ assert target_h % h == 0
190
+ assert target_w % w == 0
191
+ factor_h , factor_w = target_h // h , target_w // w
192
+ assert factor_h == factor_w
193
+ x_p = upsample (x_p , factor_h )
194
+ x = x + x_p
195
+
196
+ mask_feats = self .tower (x )
197
+
198
+ if self .num_outputs == 0 :
199
+ mask_feats = mask_feats [:, :self .num_outputs ]
200
+
201
+ losses = {}
202
+ return mask_feats , losses
203
+
204
+ mask_branch .forward = types .MethodType (mask_branch_forward , mask_branch )
205
+
140
206
def main ():
141
207
parser = argparse .ArgumentParser (description = "Export model to the onnx format" )
142
208
parser .add_argument (
@@ -198,6 +264,9 @@ def main():
198
264
input_names = ["input_image" ]
199
265
dummy_input = torch .zeros ((1 , 3 , height , width )).to (cfg .MODEL .DEVICE )
200
266
output_names = []
267
+ if isinstance (model , condinst .CondInst ):
268
+ patch_condinst (cfg , model , output_names )
269
+
201
270
if isinstance (model , BlendMask ):
202
271
patch_blendmask (cfg , model , output_names )
203
272
@@ -209,14 +278,18 @@ def main():
209
278
patch_fcos (cfg , model .proposal_generator )
210
279
patch_fcos_head (cfg , model .proposal_generator .fcos_head )
211
280
281
+ if hasattr (model , 'mask_branch' ):
282
+ if isinstance (model .mask_branch , MaskBranch ):
283
+ patch_mask_branch (cfg , model .mask_branch ) # replace aligned_bilinear with nearest upsample
284
+
212
285
torch .onnx .export (
213
286
model ,
214
287
dummy_input ,
215
288
args .output ,
216
289
verbose = True ,
217
290
input_names = input_names ,
218
291
output_names = output_names ,
219
- keep_initializers_as_inputs = True
292
+ keep_initializers_as_inputs = True ,
220
293
)
221
294
222
295
logger .info ("Done. The onnx model is saved into {}." .format (args .output ))
0 commit comments