
Commit fc58ac3

icemelon authored and wweic committed
[Relay][OP] Fix bias_add default axis (apache#2829)
* Fix bias add default axis
* update
* Fix canonicalize ops for bias_add
1 parent 2e5a5c6 commit fc58ac3

5 files changed: +12 −8 lines
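For context: relay.nn.bias_add defaults to axis=1, while a dense layer's bias belongs on the last (units) axis, so the call sites below now pass axis=-1 explicitly, and the canonicalization pass learns to normalize negative axes. A minimal sketch of the pattern; the names and shapes are illustrative, not from the diff:

    from tvm import relay

    # dense produces (batch, units); the bias lives on the last axis
    x = relay.var("x", shape=(8, 784), dtype="float32")
    w = relay.var("w", shape=(128, 784), dtype="float32")
    b = relay.var("b", shape=(128,), dtype="float32")

    y = relay.nn.dense(x, w, units=128)    # shape (8, 128)
    y = relay.nn.bias_add(y, b, axis=-1)   # explicit axis, as in this commit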

python/tvm/relay/frontend/mxnet.py (+2 −2)

@@ -34,7 +34,7 @@ def _mx_fully_connected(inputs, attrs):
     res = _op.nn.dense(inputs[0], inputs[1], units=units)
     if use_bias:
         assert len(inputs) == 3
-        res = _op.nn.bias_add(res, inputs[2])
+        res = _op.nn.bias_add(res, inputs[2], axis=-1)
     return res


@@ -413,7 +413,7 @@ def _mx_batch_dot(inputs, attrs):
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
     if transpose_b is False:
         b = _op.transpose(b, axes=[0, 2, 1])
-    return _op.batch_matmul(a, b)
+    return _op.nn.batch_matmul(a, b)


 def _mx_arange(inputs, attrs):
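The second hunk also corrects the namespace: batch_matmul lives under _op.nn. As a reading aid, an illustrative NumPy model of relay.nn.batch_matmul's semantics (my sketch, not part of the patch): the second operand is treated as transposed, which is why the converter transposes b when transpose_b is False.

    import numpy as np

    # nn.batch_matmul: a is (B, M, K), b is (B, N, K), result is (B, M, N)
    def batch_matmul_ref(a, b):
        return np.einsum("bmk,bnk->bmn", a, b)

    a = np.random.rand(2, 3, 4)
    b = np.random.rand(2, 5, 4)
    assert batch_matmul_ref(a, b).shape == (2, 3, 5)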

python/tvm/relay/testing/inception_v3.py (+1 −1)

@@ -248,7 +248,7 @@ def get_net(batch_size,

     flatten = relay.nn.batch_flatten(pool)
     fc1 = relay.nn.dense(flatten, relay.var("fc1_weight"), units=num_classes)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc2_bias"), axis=-1)
     inception_v3 = relay.nn.softmax(data=fc1)
     args = relay.ir_pass.free_vars(inception_v3)
     return relay.Function(args, inception_v3)

python/tvm/relay/testing/layers.py (+1 −1)

@@ -134,5 +134,5 @@ def dense_add_bias(data, weight=None, bias=None, units=None, **kwargs):
     if not bias:
         bias = relay.var(name + "_bias")
     data = relay.nn.dense(data, weight, units, **kwargs)
-    data = relay.nn.bias_add(data, bias)
+    data = relay.nn.bias_add(data, bias, axis=-1)
     return data
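For illustration, a hypothetical call to this helper; that the name keyword seeds the generated "<name>_weight" / "<name>_bias" variables is inferred from the hunk above, not confirmed by the diff:

    from tvm import relay
    from tvm.relay.testing import layers

    data = relay.var("data", shape=(8, 64))
    out = layers.dense_add_bias(data, units=32, name="fc")
    # the bias is now added along the last (units) axis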

python/tvm/relay/testing/mlp.py (+3 −3)

@@ -50,13 +50,13 @@ def get_net(batch_size,
                     dtype=dtype)
     data = relay.nn.batch_flatten(data)
     fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
-    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"))
+    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
     act1 = relay.nn.relu(fc1)
     fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
-    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"))
+    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"), axis=-1)
     act2 = relay.nn.relu(fc2)
     fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=num_classes)
-    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"))
+    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)
     mlp = relay.nn.softmax(data=fc3)
     args = relay.ir_pass.free_vars(mlp)
     return relay.Function(args, mlp)
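Every activation in this MLP is 2-D, so axis=-1 and the old default axis=1 address the same units dimension; the explicit axis documents intent rather than changing numerics here. An illustrative NumPy check, not part of the diff:

    import numpy as np

    fc1_out = np.zeros((8, 128))               # (batch, units)
    fc1_bias = np.zeros(128)
    out = fc1_out + fc1_bias.reshape([1, -1])  # broadcast along the last axis
    assert out.shape == (8, 128)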

src/relay/pass/canonicalize_ops.cc (+5 −1)

@@ -24,7 +24,11 @@ class BiasAddSimplifier : public ExprMutator {

       auto ttype = n->args[0]->type_as<TensorTypeNode>();
       size_t n_dim = ttype->shape.size();
-      Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {param->axis});
+      int axis = param->axis;
+      if (axis < 0) {
+        axis += n_dim;
+      }
+      Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis});
       Expr ret = Add(call->args[0], expanded_bias);
       ret->checked_type_ = n->checked_type_;
       return ret;
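Before this fix, a negative axis reached ExpandBiasToMatchAxis unnormalized. A Python mirror of the added C++ logic, with illustrative checks (my sketch, not code from the patch):

    def normalize_axis(axis, n_dim):
        # negative axes count from the end, as in NumPy
        if axis < 0:
            axis += n_dim
        assert 0 <= axis < n_dim
        return axis

    assert normalize_axis(-1, 2) == 1   # dense output: the units axis
    assert normalize_axis(1, 4) == 1    # NCHW conv output: the channel axis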
