Skip to content

Commit 33392db

Browse files
author
local
committed
Add CondInst Onnx support
1 parent 38c466c commit 33392db

File tree

1 file changed

+78
-5
lines changed

1 file changed

+78
-5
lines changed

onnx/export_model_to_onnx.py

+78-5
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,27 @@
4040
from detectron2.modeling import ProposalNetwork
4141

4242
from adet.config import get_cfg
43-
from adet.modeling import FCOS, BlendMask
43+
from adet.modeling import FCOS, BlendMask, BAText, MEInst, condinst, SOLOv2
44+
from adet.modeling.condinst.mask_branch import MaskBranch
45+
46+
def patch_condinst(cfg, model, output_names):
47+
def forward(self, tensor):
48+
images = None
49+
gt_instances = None
50+
mask_feats = None
51+
proposals = None
52+
53+
features = self.backbone(tensor)
54+
#return features
55+
proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, self.controller)
56+
#return proposals
57+
mask_feats, sem_losses = self.mask_branch(features, gt_instances)
58+
#return mask_feats
59+
return mask_feats, proposals
60+
61+
model.forward = types.MethodType(forward, model)
62+
63+
#output tensor naming [optional]
4464

4565
def patch_blendmask(cfg, model, output_names):
4666
def forward(self, tensor):
@@ -54,7 +74,8 @@ def forward(self, tensor):
5474
return basis_out["bases"], proposals
5575

5676
model.forward = types.MethodType(forward, model)
57-
# output
77+
78+
#output tensor naming [optional]
5879
output_names.extend(["bases"])
5980
for item in ["logits", "bbox_reg", "centerness", "top_feats"]:
6081
for l in range(len(cfg.MODEL.FCOS.FPN_STRIDES)):
@@ -71,7 +92,8 @@ def forward(self, tensor):
7192
return proposals
7293

7394
model.forward = types.MethodType(forward, model)
74-
# output
95+
96+
#output tensor naming [optional]
7597
for item in ["logits", "bbox_reg", "centerness"]:
7698
for l in range(len(cfg.MODEL.FCOS.FPN_STRIDES)):
7799
fpn_name = "P{}".format(3 + l)
@@ -95,7 +117,7 @@ def patch_fcos_head(cfg, fcos_head):
95117
"share": (cfg.MODEL.FCOS.NUM_SHARE_CONVS,
96118
False)}
97119

98-
# step 2. seperate module
120+
# step 2. separate module
99121
for l in range(fcos_head.num_levels):
100122
for head in head_configs:
101123
tower = []
@@ -137,6 +159,50 @@ def fcos_head_forward(self, x, top_module=None, yield_bbox_towers=False):
137159

138160
fcos_head.forward = types.MethodType(fcos_head_forward, fcos_head)
139161

162+
def upsample(tensor, factor): # aligned_bilinear in adet/utils/comm.py is not onnx-friendly
163+
assert tensor.dim() == 4
164+
assert factor >= 1
165+
assert int(factor) == factor
166+
167+
if factor == 1:
168+
return tensor
169+
170+
h, w = tensor.size()[2:]
171+
oh = factor * h
172+
ow = factor * w
173+
tensor = F.interpolate(
174+
tensor, size=(oh, ow),
175+
mode='nearest',
176+
)
177+
return tensor
178+
179+
def patch_mask_branch(cfg, mask_branch):
180+
def mask_branch_forward(self, features, gt_instances=None):
181+
for i, f in enumerate(self.in_features):
182+
if i == 0:
183+
x = self.refine[i](features[f])
184+
else:
185+
x_p = self.refine[i](features[f])
186+
187+
target_h, target_w = x.size()[2:]
188+
h, w = x_p.size()[2:]
189+
assert target_h % h == 0
190+
assert target_w % w == 0
191+
factor_h, factor_w = target_h // h, target_w // w
192+
assert factor_h == factor_w
193+
x_p = upsample(x_p, factor_h)
194+
x = x + x_p
195+
196+
mask_feats = self.tower(x)
197+
198+
if self.num_outputs == 0:
199+
mask_feats = mask_feats[:, :self.num_outputs]
200+
201+
losses = {}
202+
return mask_feats, losses
203+
204+
mask_branch.forward = types.MethodType(mask_branch_forward, mask_branch)
205+
140206
def main():
141207
parser = argparse.ArgumentParser(description="Export model to the onnx format")
142208
parser.add_argument(
@@ -198,6 +264,9 @@ def main():
198264
input_names = ["input_image"]
199265
dummy_input = torch.zeros((1, 3, height, width)).to(cfg.MODEL.DEVICE)
200266
output_names = []
267+
if isinstance(model, condinst.CondInst):
268+
patch_condinst(cfg, model, output_names)
269+
201270
if isinstance(model, BlendMask):
202271
patch_blendmask(cfg, model, output_names)
203272

@@ -209,14 +278,18 @@ def main():
209278
patch_fcos(cfg, model.proposal_generator)
210279
patch_fcos_head(cfg, model.proposal_generator.fcos_head)
211280

281+
if hasattr(model, 'mask_branch'):
282+
if isinstance(model.mask_branch, MaskBranch):
283+
patch_mask_branch(cfg, model.mask_branch) # replace aligned_bilinear with nearest upsample
284+
212285
torch.onnx.export(
213286
model,
214287
dummy_input,
215288
args.output,
216289
verbose=True,
217290
input_names=input_names,
218291
output_names=output_names,
219-
keep_initializers_as_inputs=True
292+
keep_initializers_as_inputs=True,
220293
)
221294

222295
logger.info("Done. The onnx model is saved into {}.".format(args.output))

0 commit comments

Comments
 (0)