Skip to content

Commit bbcbcbb

Browse files
yTakatsukasawweic
authored andcommitted
Consistent result of DetectLinearEquation() when an empy vars is passed (apache#2860)
1 parent 8fde500 commit bbcbcbb

File tree

3 files changed

+27
-22
lines changed

3 files changed

+27
-22
lines changed

src/arithmetic/detect_linear_equation.cc

+13-17
Original file line numberDiff line numberDiff line change
@@ -127,25 +127,21 @@ Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
127127
Expr base = e;
128128
Array<Expr> coeff;
129129

130-
if (0 == vars.size()) {
131-
coeff.push_back(make_const(Int(32), 1));
132-
} else {
133-
for (Var v : vars) {
134-
LinearEqEntry ret;
135-
if (!LinearEqDetector(v).Detect(base, &ret)) {
136-
return Array<Expr>();
137-
}
138-
coeff.push_back(ret.coeff);
139-
base = std::move(ret.base);
130+
for (Var v : vars) {
131+
LinearEqEntry ret;
132+
if (!LinearEqDetector(v).Detect(base, &ret)) {
133+
return Array<Expr>();
140134
}
135+
coeff.push_back(ret.coeff);
136+
base = std::move(ret.base);
137+
}
141138

142-
std::unordered_set<const Variable*> vset;
143-
for (size_t i = vars.size(); i != 1; --i) {
144-
vset.insert(vars[i - 1].get());
145-
// The previous coeff contains the variable
146-
if (ExprUseVar(coeff[i - 2], vset)) {
147-
return Array<Expr>();
148-
}
139+
std::unordered_set<const Variable*> vset;
140+
for (size_t i = vars.size(); i > 1; --i) {
141+
vset.insert(vars[i - 1].get());
142+
// The previous coeff contains the variable
143+
if (ExprUseVar(coeff[i - 2], vset)) {
144+
return Array<Expr>();
149145
}
150146
}
151147
coeff.push_back(base);

src/pass/inject_copy_intrin.cc

+6-5
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator {
3939
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
4040
using namespace arith;
4141
Stmt body = stmt;
42-
bool is_single_point_copy = false;
4342

4443
// strip the loops
4544
std::vector<const For*> loops;
@@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator {
6059
const Cast* cast = store->value.as<Cast>();
6160
const Load* load = store->value.as<Load>();
6261
if (0 == loops.size()) {
63-
is_single_point_copy = true;
6462
CHECK(!has_cond);
6563
}
6664
// for now only support true condition matching
@@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator {
8381
arith::DetectLinearEquation(load->index, loop_vars);
8482
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
8583
Array<Expr> dst_shape;
86-
auto loop_var_size = loop_vars.size();
87-
if (is_single_point_copy) {
88-
loop_var_size = 1;
84+
const size_t loop_var_size = loop_vars.size();
85+
if (loop_var_size == 0) {
8986
dst_shape.push_back(make_const(Int(32), 1));
9087
} else {
9188
for (const For* op : loops) {
@@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator {
132129
CHECK_EQ(load_strides.size(), loop_var_size + 1);
133130
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
134131
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
132+
if (loop_var_size == 0) {
133+
src_strides.push_back(make_const(Int(32), 1));
134+
dst_strides.push_back(make_const(Int(32), 1));
135+
}
135136
Buffer dst = BufferNode::make(
136137
Var(store->buffer_var.node_),
137138
store->value.type(),

tests/python/unittest/test_arith_detect_linear_equation.py

+8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ def test_basic():
2020
m = tvm.arith.DetectLinearEquation(b * 7, [a])
2121
assert m[0].value == 0
2222

23+
m = tvm.arith.DetectLinearEquation(b * 7, [])
24+
assert len(m) == 1
25+
assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
26+
2327
def test_multivariate():
2428
v = [tvm.var("v%d" % i) for i in range(4)]
2529
b = tvm.var("b")
@@ -42,6 +46,10 @@ def test_multivariate():
4246
assert(m[0].value == 0)
4347
assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
4448

49+
m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
50+
assert(len(m) == 1)
51+
assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
52+
4553
if __name__ == "__main__":
4654
test_basic()
4755
test_multivariate()

0 commit comments

Comments
 (0)