Skip to content

Commit 0e05cbf

Browse files
zhiicswweic
authored and committed
[RELAY][Frontend][TF] decompile tf control flow (apache#2830)
* decompile tf control flow * Add docs * remove import relay * move tests under tensorflow frontend * minor fix
1 parent 72a8924 commit 0e05cbf

File tree

3 files changed

+630
-12
lines changed

3 files changed

+630
-12
lines changed

python/tvm/relay/frontend/tensorflow.py

+317-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77
import warnings
8+
from collections import defaultdict
89
# Numpy support
910
import numpy as np
1011

@@ -1270,6 +1271,220 @@ def _get_abs_layer_name(node):
12701271
params, num_layers)
12711272
return sym
12721273

1274+
# An internal list to contain all the control flow primitives used in
# TensorFlow 1.x. These ops are matched by name during graph import and
# dispatched to _convert_control_flow_operator instead of the regular op map.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']
1277+
1278+
def _in_while_loop(control_flow_node_map, op_name):
1279+
"""
1280+
Check if a given control flow operator is part of a while loop execution
1281+
frame. This is based on the fact that there is only one occurrence of
1282+
`LoopCond` for a loop execution frame and it is only presented in the loop
1283+
construct.
1284+
1285+
Parameters
1286+
----------
1287+
control_flow_node_map : Dict[str, Set[str]]
1288+
A dictionay contains the unqiue control flow execution frame name to
1289+
a set of primitive operators mapping.
1290+
1291+
op_name : str
1292+
The name of a control flow primitive.
1293+
1294+
Returns
1295+
-------
1296+
ret : bool
1297+
Return true if the operator is in a while loop execution frame,
1298+
otherwise, return false.
1299+
"""
1300+
return op_name in control_flow_node_map and \
1301+
"LoopCond" in control_flow_node_map[op_name]
1302+
1303+
1304+
class Branch:
    """A class containing the components that are used to build up a Relay if
    node.

    Parameters
    ----------
    cond : tvm.relay.Expr
        The condition of an if node.

    true_branch : tvm.relay.Expr
        The body of the true branch of an if expression.

    false_branch: tvm.relay.Expr
        The body of the false branch of an if expression.

    _if : tvm.relay.Expr
        An internal variable indicating whether an if expression is already
        created for a matched TF condition construct.

    Examples
    --------
    The following is a cond statement written in TensorFlow:

    .. code-block:: python

        def vanilla_cond():
            i = tf.constant(1)
            j = tf.constant(4)

            def f1():
                return tf.multiply(1, 17)

            def f2():
                return tf.add(4, 23)
            r = tf.cond(tf.less(i, j), f1, f2)

    This condition statement should be converted into Relay in the following
    form:

    .. code-block:: python

        fn (%Const: Tensor[(1,), int32],
            %Const_1: Tensor[(1,), int32],
            %cond/Mul/x: Tensor[(1,), int32],
            %cond/Mul/y: Tensor[(1,), int32],
            %cond/Add/x: Tensor[(1,), int32],
            %cond/Add/y: Tensor[(1,), int32]) {
          %0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
          %1 = min(%0)
          if (%1) {
            %2 = multiply(%cond/Mul/x, %cond/Mul/y)
            %2
          } else {
            %3 = add(%cond/Add/x, %cond/Add/y)
            %3
          }
        }
    """
    def __init__(self):
        self._if = None
        self.cond = None
        self.true_branch = None
        self.false_branch = None

    def _if_node(self):
        """An internal API to create a relay if node from the matched TF
        condition construct.
        """
        # `cond` returns a tensor that contains boolean values. We add a `min`
        # operator to check if there is any false value. If so, this condition
        # does not hold.
        cond = tvm.relay.op.min(self.cond)
        return tvm.relay.If(cond, self.true_branch, self.false_branch)

    def if_node(self):
        """Create a tvm.relay.If node if it hasn't been created yet."""
        # Cache the node so that multiple Merge outputs reuse the same If.
        if self._if is None:
            self._if = self._if_node()
        return self._if
1383+
1384+
1385+
class Loop:
    """
    A class containing the components that are used to build up a Relay
    recursive call.

    Parameters
    ----------
    loop_vars : List[tvm.relay.Expr]
        The loop variables that are used in a while loop.

    cond : tvm.relay.Expr
        The condition of a while loop.

    body : tvm.relay.Expr
        The body of a matched while loop.

    _loop : tvm.relay.Expr
        An internal variable indicating whether a recursive call is already
        created for a matched TF while loop construct.

    Examples
    --------
    The following is a vanilla loop from TensorFlow:

    .. code-block:: python

        i = tf.constant(0)
        c = lambda i: tf.less(i, 10)
        b = lambda i: tf.add(i, 1)
        r = tf.while_loop(c, b, [i])

    It will be converted to the following recursive call in Relay:

    .. code-block:: python

        fn (%while/Less/y: Tensor[(1,), int32],
            %while/Add/y: Tensor[(1,), int32],
            %Const: Tensor[(1,), int32]) {
          %0 = fn(%loop_var0: Tensor[(1,), int32]) {
            %1 = less(%loop_var0, %while/Less/y)
            %2 = min(%1)
            if (%2) {
              %3 = add(%loop_var0, %while/Add/y)
              free_var %while_loop
              %4 = %while_loop(%3)
              %4
            } else {
              %5 = (%loop_var0,)
              %5
            }
          }
          let %while_loop1 = %0
          %6 = %while_loop1(%Const)
          %6
        }
    """
    def __init__(self):
        self.loop_vars = []
        self.cond = None
        self.body = []
        self._loop = None

    def _while_loop(self):
        """An internal API to create a Relay recursive call for a matched TF
        `while_loop` construct.
        """
        wl = tvm.relay.var('while_loop')

        sb = tvm.relay.scope_builder.ScopeBuilder()

        # Replace each captured loop variable with a fresh function parameter
        # so the loop body/condition become a self-contained closure.
        loop_vars = []
        bind_map = {}
        for i, var in enumerate(self.loop_vars):
            assert isinstance(var, _expr.Var), repr(var)
            v = tvm.relay.var("loop_var" + str(i),
                              type_annotation=var.type_annotation)
            loop_vars.append(v)
            bind_map[var] = v

        self.cond = tvm.relay.bind(self.cond, bind_map)
        self.body = [tvm.relay.bind(b, bind_map) for b in self.body]

        # Reduce the (possibly tensor-valued) condition to a scalar boolean.
        cond = tvm.relay.op.min(self.cond)

        with sb.if_scope(cond):
            # Condition holds: recurse with the updated loop state.
            sb.ret(wl(*self.body))
        with sb.else_scope():
            # Condition failed: return the current loop variables as a tuple.
            sb.ret(tvm.relay.Tuple(loop_vars))

        loop_fn = tvm.relay.Function(loop_vars, sb.get())
        sb = tvm.relay.scope_builder.ScopeBuilder()
        sb.let(wl, loop_fn)
        sb.ret(wl(*self.loop_vars))
        return sb.get()

    def while_loop(self):
        """Instantiate a while loop if it has not been created yet."""
        # Cache the constructed expression; the original had an unreachable
        # duplicated return statement here, which is removed.
        if self._loop is None:
            self._loop = self._while_loop()
        return self._loop
1486+
1487+
12731488
class GraphProto(object):
12741489
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
12751490
Definition:
@@ -1284,6 +1499,8 @@ def __init__(self):
12841499
self._num_rnn_layer = False
12851500
self._outputs_are_0d = {}
12861501
self._input_shapes = {}
1502+
self._loops = {}
1503+
self._branches = {}
12871504

12881505
def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
12891506
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1332,7 +1549,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
13321549
raise NotImplementedError( \
13331550
"The following operators are not implemented: {}".format(missing_operators))
13341551

1552+
control_flow_node_map = defaultdict(set)
13351553
for node in graph.node:
1554+
node_name_prefix = node.name.rsplit('/', 1)[0]
1555+
control_flow_node_map[node_name_prefix].add(node.op)
13361556
if node.op == 'Placeholder':
13371557
if shape and node.name in shape:
13381558
self._input_shapes[node.name] = list(shape[node.name])
@@ -1447,12 +1667,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
14471667
# This means the node is 1d in Relay and 0d in TF.
14481668
# See `_expand_dims_0d_aware`.
14491669
if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
1450-
input_0d_mismatch.add(in_sym)
1670+
input_0d_mismatch.add(in_sym[0])
14511671

14521672
attr['_input_shapes'] = input_shapes
14531673
attr['_input_0d_mismatch'] = input_0d_mismatch
14541674

1455-
op = self._convert_operator(node.op, inputs, attr, graph)
1675+
if node.op in _control_flow_nodes:
1676+
op = self._convert_control_flow_operator(node, inputs,
1677+
attr,
1678+
control_flow_node_map)
1679+
else:
1680+
op = self._convert_operator(node.op, inputs, attr, graph)
14561681

14571682
# Check if op is converted to param
14581683
if isinstance(op, np.ndarray):
@@ -1493,7 +1718,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
14931718

14941719
out = []
14951720
if outputs is None:
1496-
out = op
1721+
if node.op == "Exit":
1722+
out = [op[0].tuple_value]
1723+
else:
1724+
out = op
14971725
else:
14981726
for out_name in outputs:
14991727
if ":" in out_name:
@@ -1529,7 +1757,9 @@ def _parse_import_prerequisites(self, graph):
15291757
elif node.op == "Const":
15301758
pass
15311759
else:
1532-
if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
1760+
if any([node.op in t for t in [_identity_list, _convert_map,
1761+
_convert_map_rnn,
1762+
_control_flow_nodes]]):
15331763
pass
15341764
else:
15351765
missing_operators.add(node.op)
@@ -1656,6 +1886,89 @@ def _convert_rnn_operator(self, op_name, inputs,
16561886
sym = self.rnn.process_op(op_name, inputs, attrs, params)
16571887
return sym
16581888

1889+
def _convert_control_flow_operator(self, node, inputs, attrs, control_flow_node_map):
    """
    Convert the Relay control flow primitive into corresponding component
    of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
    are converted into Relay `If` and recursive call, respectively.

    Parameters
    ----------
    node: TensorFlow graph node object.
        A TensorFlow graph node object.

    inputs : List[tvm.relay.Expr]
        List of input symbols.

    attrs : Dict[tvm.Attrs]
        Dict of operator attributes.

    control_flow_node_map : Dict[str, Set[str]]
        A dictionary contains the execution frame name to primitives
        mapping.

    Returns
    -------
    op : tvm.relay.Expr
        Converted relay expression.
    """
    # All primitives of one TF control flow construct share a name prefix
    # (e.g. "while/..." or "cond/..."), which identifies the frame.
    node_name_prefix = node.name.rsplit('/', 1)[0]
    if node.op == "Merge":
        if _in_while_loop(control_flow_node_map, node_name_prefix):
            op = self._nodes[node.input[0]]
            self._loops[node_name_prefix] = Loop()
        else:
            if not self._branches:
                raise RuntimeError("Cannot find a created "
                                   "conditional for merge node")
            branch = self._branches[node_name_prefix]
            false_br = self._nodes[node.input[0]]
            true_br = self._nodes[node.input[1]]
            assert len(true_br) == 1
            assert len(false_br) == 1
            branch.true_branch = true_br[0]
            branch.false_branch = false_br[0]
            op = [branch.if_node()]
    elif node.op == "Exit":
        loop = self._loops[node_name_prefix]
        exit_name = node.name.split('/')[-1]
        assert exit_name.startswith('Exit')

        # TensorFlow has different naming conventions on different
        # versions: "Exit_<i>" vs "Exit<i>". The first Exit has no number
        # suffix, hence the "0" prefix to keep int() from failing on "".
        if '_' in exit_name:
            exit_number = int("0" + exit_name[5:])
        else:
            exit_number = int("0" + exit_name[4:])

        expr = loop.while_loop()
        # Each Exit node selects one loop variable out of the result tuple.
        op = _expr.TupleGetItem(expr, exit_number)
    elif node.op == "Enter":
        # Enter just forwards its input into the execution frame.
        op = self._nodes[node.input[0]]
    elif node.op == "LoopCond":
        op = self._nodes[node.input[0]]
        assert len(op) == 1
        self._loops[node_name_prefix].cond = op[0]
    elif node.op == "Switch":
        op = self._nodes[node.input[0]]
        assert len(op) == 1
        if _in_while_loop(control_flow_node_map, node_name_prefix):
            # Inside a while loop, Switch inputs are the loop variables.
            self._loops[node_name_prefix].loop_vars.append(op[0])
        else:
            # Inside a cond, the Switch predicate becomes the If condition.
            if node_name_prefix not in self._branches:
                self._branches[node_name_prefix] = Branch()
            self._branches[node_name_prefix].cond = ir_pass.infer_type(op[0])
    elif node.op == "NextIteration":
        op = self._nodes[node.input[0]]
        assert len(op) == 1
        self._loops[node_name_prefix].body.append(op[0])
    else:
        raise Exception("Cannot identify control flow operator: " +
                        "{}".format(node.op))

    return op
1970+
1971+
16591972
def _convert_operator(self, op_name, inputs, attrs,
16601973
graph, identity_list=None, convert_map=None):
16611974
"""Convert from Tensorflow operator to relay operator.

0 commit comments

Comments
 (0)