Skip to content

Commit e855f91

Browse files
vinx13wweic
authored andcommitted
[Relay, TOPI] Deformable conv2d (apache#2908)
* [Relay, TOPI] Add deformable conv2d * Moved to op level2 * Fix lint * Moved to level2 & bug fix * Update comments * Disabled flaky test of conv2d
1 parent 32d0731 commit e855f91

File tree

17 files changed

+821
-2
lines changed

17 files changed

+821
-2
lines changed

include/tvm/relay/attrs/nn.h

+61
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,67 @@ struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
456456
}
457457
};
458458

459+
460+
/*! \brief Attributes for DeformableConv2D operator */
461+
struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
462+
Array<IndexExpr> strides;
463+
Array<IndexExpr> padding;
464+
Array<IndexExpr> dilation;
465+
int deformable_groups;
466+
int groups;
467+
IndexExpr channels;
468+
Array<IndexExpr> kernel_size;
469+
std::string data_layout;
470+
std::string kernel_layout;
471+
std::string out_layout;
472+
DataType out_dtype;
473+
474+
TVM_DECLARE_ATTRS(DeformableConv2DAttrs, "relay.attrs.DeformableConv2DAttrs") {
475+
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
476+
.describe("Specifies the strides of the convolution.");
477+
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
478+
.describe("If padding is non-zero, then the input is implicitly zero-padded"
479+
"on both sides for padding number of points");
480+
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
481+
.describe("Specifies the dilation rate to use for dilated convolution.");
482+
TVM_ATTR_FIELD(deformable_groups).set_default(1)
483+
.describe("Controls the connections between inputs and offsets."
484+
"Input channels are partitioned into multiple deformable groups. Offsets"
485+
"are shared across input channels in the same deformable group.");
486+
TVM_ATTR_FIELD(groups).set_default(1)
487+
.describe("Controls the connections between inputs and outputs."
488+
"At groups=1, all inputs are convolved to all outputs."
489+
"At groups=2, the operation becomes equivalent to having two convolution"
490+
"layers side by side, each seeing half the input channels, and producing"
491+
"half the output channels, and both subsequently concatenated.");
492+
TVM_ATTR_FIELD(channels)
493+
.describe("The number of output channels in the convolution."
494+
" If it is not set, inferred by shape of the weight.")
495+
.set_default(NullValue<IndexExpr>());
496+
TVM_ATTR_FIELD(kernel_size)
497+
.describe("Specifies the dimensions of the convolution window.")
498+
.set_default(NullValue<Array<IndexExpr> >());
499+
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
500+
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
501+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
502+
"dimensions respectively. Convolution is applied on the 'H' and"
503+
"'W' dimensions.");
504+
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW")
505+
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
506+
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
507+
"dimensions respectively.");
508+
TVM_ATTR_FIELD(out_layout).set_default("")
509+
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
510+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
511+
"dimensions respectively. Default to be same as input layout.");
512+
513+
// use 0 bits to indicate none.
514+
TVM_ATTR_FIELD(out_dtype)
515+
.set_default(NullValue<DataType>())
516+
.describe("Output data type, set to explicit type under mixed precision setting");
517+
}
518+
};
519+
459520
} // namespace relay
460521
} // namespace tvm
461522
#endif // TVM_RELAY_ATTRS_NN_H_

python/tvm/autotvm/task/relay_integration.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def extract_from_program(func, params, ops, target, target_host=None):
5353
topi.nn.group_conv2d_nchw],
5454
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
5555
tvm.relay.op.nn.dense: [topi.nn.dense],
56+
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
5657
}
5758

5859
topi_funcs = []
@@ -126,6 +127,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
126127
topi.nn.group_conv2d_nchw],
127128
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
128129
tvm.relay.op.nn.dense: [topi.nn.dense],
130+
tvm.relay.op.nn.contrib_deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
129131
}
130132

131133
topi_funcs = []

python/tvm/autotvm/task/topi_integration.py

+11
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(self):
6868
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
6969
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
7070
topi.nn.dense: "topi_nn_dense",
71+
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
7172
}
7273

7374
self.topi_to_schedule = {
@@ -78,6 +79,7 @@ def __init__(self):
7879
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
7980
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
8081
topi.nn.dense: [topi.generic.schedule_dense],
82+
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
8183
}
8284

8385
self._register_tracing()
@@ -172,6 +174,15 @@ def _topi_nn_dense(*args, **kwargs):
172174
return s, [data, weight, bias, C]
173175
return s, [data, weight, C]
174176

177+
@register("topi_nn_deformable_conv2d_nchw")
178+
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
179+
assert not kwargs, "Do not support kwargs in template function call"
180+
args = deserialize_args(args)
181+
A, Offset, W = args[:3]
182+
C = topi.nn.deformable_conv2d_nchw(*args, **kwargs)
183+
s = topi.generic.schedule_deformable_conv2d_nchw([C])
184+
return s, [A, Offset, W, C]
185+
175186
def reset(self, wanted_topi_funcs):
176187
"""Reset task collections
177188

python/tvm/relay/frontend/mxnet.py

+20
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,25 @@ def _mx_smooth_l1(inputs, attrs):
603603
_op.abs(inputs[0]) - _expr.const(0.5 / scalar_sq))
604604

605605

606+
def _mx_deformable_convolution(inputs, attrs):
607+
new_attrs = {}
608+
assert attrs.get_bool("no_bias")
609+
new_attrs["kernel_size"] = attrs.get_int_tuple("kernel")
610+
new_attrs["strides"] = attrs.get_int_tuple("stride")
611+
new_attrs["padding"] = attrs.get_int_tuple("pad")
612+
new_attrs["dilation"] = attrs.get_int_tuple("dilate")
613+
new_attrs["channels"] = attrs.get_int("num_filter")
614+
new_attrs["deformable_groups"] = attrs.get_int("num_deformable_group", 1)
615+
new_attrs["groups"] = attrs.get_int("num_group", 1)
616+
assert attrs.get_str("layout", "NCHW") == "NCHW", "Deformable conv2d only supports NCHW layout"
617+
use_bias = not attrs.get_bool("no_bias", False)
618+
res = _op.nn.deformable_conv2d(inputs[0], inputs[1], inputs[2], **new_attrs)
619+
if use_bias:
620+
assert len(inputs) == 4
621+
res = _op.nn.bias_add(res, inputs[3])
622+
return res
623+
624+
606625
# Note: due to attribute conversion constraint
607626
# ops in the identity set must be attribute free
608627
_identity_list = [
@@ -748,6 +767,7 @@ def _mx_smooth_l1(inputs, attrs):
748767
"_contrib_Proposal" : _mx_proposal,
749768
"_contrib_MultiProposal" : _mx_proposal,
750769
"_contrib_box_nms" : _mx_box_nms,
770+
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
751771
# List of missing operators that are present in NNVMv1
752772
# TODO(tvm-tvm): support all operators.
753773
#

python/tvm/relay/op/nn/_nn.py

+23
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,26 @@ def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
426426

427427
reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
428428
OpPattern.OUT_ELEMWISE_FUSABLE)
429+
430+
@reg.register_compute("nn.deformable_conv2d")
431+
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
432+
"""Compute definition of deformable_conv2d"""
433+
padding = get_const_tuple(attrs.padding)
434+
strides = get_const_tuple(attrs.strides)
435+
dilation = get_const_tuple(attrs.dilation)
436+
deformable_groups = attrs.deformable_groups
437+
groups = attrs.groups
438+
out_dtype = attrs.out_dtype
439+
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
440+
with target:
441+
out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
442+
dilation, deformable_groups, groups, out_dtype)
443+
return [out]
444+
445+
@reg.register_schedule("nn.deformable_conv2d")
446+
def schedule_deformable_conv2d(attrs, outs, target):
447+
"""Schedule definition of deformable_conv2d"""
448+
with target:
449+
return topi.generic.schedule_deformable_conv2d_nchw(outs)
450+
451+
reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

python/tvm/relay/op/nn/nn.py

+73
Original file line numberDiff line numberDiff line change
@@ -1105,3 +1105,76 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
11051105
"""
11061106
return _make.contrib_conv2d_winograd_nnpack_weight_transform(
11071107
weight, convolution_algorithm, out_dtype)
1108+
1109+
1110+
def deformable_conv2d(data,
1111+
offset,
1112+
weight,
1113+
strides=(1, 1),
1114+
padding=(0, 0),
1115+
dilation=(1, 1),
1116+
deformable_groups=1,
1117+
groups=1,
1118+
channels=None,
1119+
kernel_size=None,
1120+
data_layout='NCHW',
1121+
kernel_layout='OIHW',
1122+
out_layout='',
1123+
out_dtype=''):
1124+
r""" Deformable 2d convolution.
1125+
1126+
The deformable convolution operation is described in https://arxiv.org/abs/1703.06211
1127+
1128+
Parameters
1129+
----------
1130+
data : tvm.relay.Expr
1131+
The input data to the operator.
1132+
1133+
offset : tvm.relay.Expr
1134+
The offset expressions.
1135+
1136+
weight : tvm.relay.Expr
1137+
The weight expressions.
1138+
1139+
strides : tuple of int, optional
1140+
The strides of convoltution.
1141+
1142+
padding : tuple of int, optional
1143+
The padding of convolution on both sides of inputs before convolution.
1144+
1145+
dilation : tuple of int, optional
1146+
Specifies the dilation rate to be used for dilated convolution.
1147+
1148+
deformable_groups : int, optional
1149+
Number of deformable groups.
1150+
1151+
groups : int, optional
1152+
Number of groups for grouped convolution.
1153+
1154+
channels : int, optional
1155+
Number of output channels of this convolution.
1156+
1157+
kernel_size : tuple of int, optional
1158+
The spatial of the convolution kernel.
1159+
1160+
data_layout : str, optional
1161+
Layout of the input.
1162+
1163+
kernel_layout : str, optional
1164+
Layout of the weight.
1165+
1166+
out_layout : str, optional
1167+
Layout of the output, by default, out_layout is the same as data_layout
1168+
1169+
out_dtype : str, optional
1170+
Specifies the output data type for mixed precision conv2d.
1171+
1172+
Returns
1173+
-------
1174+
result : tvm.relay.Expr
1175+
The computed result.
1176+
1177+
"""
1178+
return _make.deformable_conv2d(data, offset, weight, strides, padding, dilation,
1179+
deformable_groups, groups, channels, kernel_size, data_layout,
1180+
kernel_layout, out_layout, out_dtype)

0 commit comments

Comments
 (0)