5
5
6
6
import logging
7
7
import warnings
8
+ from collections import defaultdict
8
9
# Numpy support
9
10
import numpy as np
10
11
@@ -1270,6 +1271,220 @@ def _get_abs_layer_name(node):
1270
1271
params , num_layers )
1271
1272
return sym
1272
1273
1274
+ # An internal list to contain all the control flow primitives used in Tensorflow
1275
+ # 1.x.
1276
+ _control_flow_nodes = ['Merge' , 'Switch' , 'NextIteration' , 'Exit' , 'Enter' , 'LoopCond' ]
1277
+
1278
+ def _in_while_loop (control_flow_node_map , op_name ):
1279
+ """
1280
+ Check if a given control flow operator is part of a while loop execution
1281
+ frame. This is based on the fact that there is only one occurrence of
1282
+ `LoopCond` for a loop execution frame and it is only presented in the loop
1283
+ construct.
1284
+
1285
+ Parameters
1286
+ ----------
1287
+ control_flow_node_map : Dict[str, Set[str]]
1288
+ A dictionay contains the unqiue control flow execution frame name to
1289
+ a set of primitive operators mapping.
1290
+
1291
+ op_name : str
1292
+ The name of a control flow primitive.
1293
+
1294
+ Returns
1295
+ -------
1296
+ ret : bool
1297
+ Return true if the operator is in a while loop execution frame,
1298
+ otherwise, return false.
1299
+ """
1300
+ return op_name in control_flow_node_map and \
1301
+ "LoopCond" in control_flow_node_map [op_name ]
1302
+
1303
+
1304
+ class Branch :
1305
+ """A class contains the components that are used to build up a Relay if
1306
+ node.
1307
+
1308
+ Parameters
1309
+ ----------
1310
+ cond : tvm.relay.Expr
1311
+ The condition of a if node.
1312
+
1313
+ true_branch : tvm.relay.Expr
1314
+ The body of the true branch of a if expression.
1315
+
1316
+ false_branch: tvm.relay.Expr
1317
+ The body of the false branch of a if expression.
1318
+
1319
+ _if : tvm.relay.Expr
1320
+ An internal variable indicates where an if expression is already created
1321
+ for a matched TF condition construct.
1322
+
1323
+ Examples
1324
+ --------
1325
+ The following is a cond statement written in TensorFlow:
1326
+
1327
+ .. code-block:: python
1328
+
1329
+ def vanilla_cond():
1330
+ i = tf.constant(1)
1331
+ j = tf.constant(4)
1332
+
1333
+ def f1():
1334
+ return tf.multiply(1, 17)
1335
+
1336
+ def f2():
1337
+ return tf.add(4, 23)
1338
+ r = tf.cond(tf.less(i, j), f1, f2)
1339
+
1340
+ This condition statement should be coverted into Relay in the following
1341
+ form:
1342
+
1343
+ .. code-block:: python
1344
+
1345
+ fn (%Const: Tensor[(1,), int32],
1346
+ %Const_1: Tensor[(1,), int32],
1347
+ %cond/Mul/x: Tensor[(1,), int32],
1348
+ %cond/Mul/y: Tensor[(1,), int32],
1349
+ %cond/Add/x: Tensor[(1,), int32],
1350
+ %cond/Add/y: Tensor[(1,), int32]) {
1351
+ %0 = less(%Const, %Const_1) # ty=Tensor[(1,), bool]
1352
+ %1 = min(%0)
1353
+ if (%1) {
1354
+ %2 = multiply(%cond/Mul/x, %cond/Mul/y)
1355
+ %2
1356
+ } else {
1357
+ %3 = add(%cond/Add/x, %cond/Add/y)
1358
+ %3
1359
+ }
1360
+ }
1361
+ """
1362
+ def __init__ (self ):
1363
+ self ._if = None
1364
+ self .cond = None
1365
+ self .true_branch = None
1366
+ self .false_branch = None
1367
+
1368
+ def _if_node (self ):
1369
+ """An internal API to create a relay if node from the matched TF
1370
+ condition construct.
1371
+ """
1372
+ # `cond` returns a tensor that contains boolean values. We add a `min`
1373
+ # operator to checks if there is any false value. If so, this condition
1374
+ # doesn't not hold.
1375
+ cond = tvm .relay .op .min (self .cond )
1376
+ return tvm .relay .If (cond , self .true_branch , self .false_branch )
1377
+
1378
+ def if_node (self ):
1379
+ """Create an tvm.relay.If node if it hasn't been created yet."""
1380
+ if self ._if is None :
1381
+ self ._if = self ._if_node ()
1382
+ return self ._if
1383
+
1384
+
1385
+ class Loop :
1386
+ """
1387
+ A class contains the components that are used to build up a Relay
1388
+ recursive call.
1389
+
1390
+ Parameters
1391
+ ----------
1392
+ loop_vars : List[tvm.relay.Expr]
1393
+ The loop variables that used in a while loop.
1394
+
1395
+ cond : tvm.relay.Expr
1396
+ The condition of a while loop.
1397
+
1398
+ body : tvm.relay.Expr
1399
+ The body of a matched while loop.
1400
+
1401
+ _loop : tvm.relay.Expr
1402
+ An internal variable indicates where a recursive call is already created
1403
+ for a matched TF while loop construct.
1404
+
1405
+ Examples
1406
+ --------
1407
+ The following is a vanilla loop from TensorFlow:
1408
+
1409
+ .. code-block:: python
1410
+
1411
+ i = tf.constant(0)
1412
+ c = lambda i: tf.less(i, 10)
1413
+ b = lambda i: tf.add(i, 1)
1414
+ r = tf.while_loop(c, b, [i])
1415
+
1416
+ It will be converted to the following recursive call in Relay:
1417
+
1418
+ .. code-block:: python
1419
+
1420
+ fn (%while/Less/y: Tensor[(1,), int32],
1421
+ %while/Add/y: Tensor[(1,), int32],
1422
+ %Const: Tensor[(1,), int32]) {
1423
+ %0 = fn(%loop_var0: Tensor[(1,), int32]) {
1424
+ %1 = less(%loop_var0, %while/Less/y)
1425
+ %2 = min(%1)
1426
+ if (%2) {
1427
+ %3 = add(%loop_var0, %while/Add/y)
1428
+ free_var %while_loop
1429
+ %4 = %while_loop(%3)
1430
+ %4
1431
+ } else {
1432
+ %5 = (%loop_var0,)
1433
+ %5
1434
+ }
1435
+ }
1436
+ let %while_loop1 = %0
1437
+ %6 = %while_loop1(%Const)
1438
+ %6
1439
+ }
1440
+ """
1441
+ def __init__ (self ):
1442
+ self .loop_vars = []
1443
+ self .cond = None
1444
+ self .body = []
1445
+ self ._loop = None
1446
+
1447
+ def _while_loop (self ):
1448
+ """An internal API to create a Relay recurisve call for a matched TF
1449
+ `while_loop` construct.
1450
+ """
1451
+ wl = tvm .relay .var ('while_loop' )
1452
+
1453
+ sb = tvm .relay .scope_builder .ScopeBuilder ()
1454
+
1455
+ loop_vars = []
1456
+ bind_map = {}
1457
+ for i , var in enumerate (self .loop_vars ):
1458
+ assert isinstance (var , _expr .Var ), repr (var )
1459
+ v = tvm .relay .var ("loop_var" + str (i ),
1460
+ type_annotation = var .type_annotation )
1461
+ loop_vars .append (v )
1462
+ bind_map [var ] = v
1463
+
1464
+ self .cond = tvm .relay .bind (self .cond , bind_map )
1465
+ self .body = [tvm .relay .bind (b , bind_map ) for b in self .body ]
1466
+
1467
+ cond = tvm .relay .op .min (self .cond )
1468
+
1469
+ with sb .if_scope (cond ):
1470
+ sb .ret (wl (* self .body ))
1471
+ with sb .else_scope ():
1472
+ sb .ret (tvm .relay .Tuple (loop_vars ))
1473
+
1474
+ loop_fn = tvm .relay .Function (loop_vars , sb .get ())
1475
+ sb = tvm .relay .scope_builder .ScopeBuilder ()
1476
+ sb .let (wl , loop_fn )
1477
+ sb .ret (wl (* self .loop_vars ))
1478
+ return sb .get ()
1479
+
1480
+ def while_loop (self ):
1481
+ """Instantiate a while loop if it has not been created yet."""
1482
+ if self ._loop is None :
1483
+ self ._loop = self ._while_loop ()
1484
+ return self ._loop
1485
+ return self ._loop
1486
+
1487
+
1273
1488
class GraphProto (object ):
1274
1489
""" A helper class for handling relay graph copying from Tensorflow GraphDef.
1275
1490
Definition:
@@ -1284,6 +1499,8 @@ def __init__(self):
1284
1499
self ._num_rnn_layer = False
1285
1500
self ._outputs_are_0d = {}
1286
1501
self ._input_shapes = {}
1502
+ self ._loops = {}
1503
+ self ._branches = {}
1287
1504
1288
1505
def from_tensorflow (self , graph , layout = "NHWC" , shape = None , outputs = None ):
1289
1506
"""Construct relay nodes from tensorflow graph definition - GraphDef.
@@ -1332,7 +1549,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1332
1549
raise NotImplementedError ( \
1333
1550
"The following operators are not implemented: {}" .format (missing_operators ))
1334
1551
1552
+ control_flow_node_map = defaultdict (set )
1335
1553
for node in graph .node :
1554
+ node_name_prefix = node .name .rsplit ('/' , 1 )[0 ]
1555
+ control_flow_node_map [node_name_prefix ].add (node .op )
1336
1556
if node .op == 'Placeholder' :
1337
1557
if shape and node .name in shape :
1338
1558
self ._input_shapes [node .name ] = list (shape [node .name ])
@@ -1447,12 +1667,17 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1447
1667
# This means the node is 1d in Relay and 0d in TF.
1448
1668
# See `_expand_dims_0d_aware`.
1449
1669
if self ._outputs_are_0d [node_name ][tensor_slot ] and input_shape :
1450
- input_0d_mismatch .add (in_sym )
1670
+ input_0d_mismatch .add (in_sym [ 0 ] )
1451
1671
1452
1672
attr ['_input_shapes' ] = input_shapes
1453
1673
attr ['_input_0d_mismatch' ] = input_0d_mismatch
1454
1674
1455
- op = self ._convert_operator (node .op , inputs , attr , graph )
1675
+ if node .op in _control_flow_nodes :
1676
+ op = self ._convert_control_flow_operator (node , inputs ,
1677
+ attr ,
1678
+ control_flow_node_map )
1679
+ else :
1680
+ op = self ._convert_operator (node .op , inputs , attr , graph )
1456
1681
1457
1682
# Check if op is converted to param
1458
1683
if isinstance (op , np .ndarray ):
@@ -1493,7 +1718,10 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1493
1718
1494
1719
out = []
1495
1720
if outputs is None :
1496
- out = op
1721
+ if node .op == "Exit" :
1722
+ out = [op [0 ].tuple_value ]
1723
+ else :
1724
+ out = op
1497
1725
else :
1498
1726
for out_name in outputs :
1499
1727
if ":" in out_name :
@@ -1529,7 +1757,9 @@ def _parse_import_prerequisites(self, graph):
1529
1757
elif node .op == "Const" :
1530
1758
pass
1531
1759
else :
1532
- if any ([node .op in t for t in [_identity_list , _convert_map , _convert_map_rnn ]]):
1760
+ if any ([node .op in t for t in [_identity_list , _convert_map ,
1761
+ _convert_map_rnn ,
1762
+ _control_flow_nodes ]]):
1533
1763
pass
1534
1764
else :
1535
1765
missing_operators .add (node .op )
@@ -1656,6 +1886,89 @@ def _convert_rnn_operator(self, op_name, inputs,
1656
1886
sym = self .rnn .process_op (op_name , inputs , attrs , params )
1657
1887
return sym
1658
1888
1889
+ def _convert_control_flow_operator (self , node , inputs , attrs , control_flow_node_map ):
1890
+ """
1891
+ Convert the Relay control flow primitive into corresponding component
1892
+ of a Relay control flow construct, i.e. `tf.cond` and `tf.while_loop`
1893
+ are converted in Relay `If` and recusrive call, respectively.
1894
+
1895
+ Parameters
1896
+ ----------
1897
+ node: TensorFlow graph node object.
1898
+ A TensorFlow graph node object.
1899
+
1900
+ inputs : List[tvm.relay.Expr]
1901
+ List of input symbols.
1902
+
1903
+ attrs : Dict[tvm.Attrs]
1904
+ Dict of operator attributes.
1905
+
1906
+ control_flow_node_map : Dict[str, Set[str]]
1907
+ A dictionary contains the execution frame name to primitives
1908
+ mapping.
1909
+
1910
+ Returns
1911
+ -------
1912
+ op : tvm.relay.Expr
1913
+ Converted relay expression.
1914
+ """
1915
+ node_name_prefix = node .name .rsplit ('/' , 1 )[0 ]
1916
+ if node .op == "Merge" :
1917
+ if _in_while_loop (control_flow_node_map , node_name_prefix ):
1918
+ op = self ._nodes [node .input [0 ]]
1919
+ self ._loops [node_name_prefix ] = Loop ()
1920
+ else :
1921
+ if len (self ._branches ) == 0 :
1922
+ raise RuntimeError ("Cannot find a created "
1923
+ "conditional for merge node" )
1924
+ branch = self ._branches [node_name_prefix ]
1925
+ false_br = self ._nodes [node .input [0 ]]
1926
+ true_br = self ._nodes [node .input [1 ]]
1927
+ assert len (true_br ) == 1
1928
+ assert len (false_br ) == 1
1929
+ branch .true_branch = true_br [0 ]
1930
+ branch .false_branch = false_br [0 ]
1931
+ op = [branch .if_node ()]
1932
+ elif node .op == "Exit" :
1933
+ loop = self ._loops [node_name_prefix ]
1934
+ exit_name = node .name .split ('/' )[- 1 ]
1935
+ assert str .startswith (exit_name , 'Exit' )
1936
+
1937
+ # TensorFlow has differen naming convention on different
1938
+ # versions.
1939
+ if '_' in exit_name :
1940
+ exit_number = int ("0" + exit_name [5 :])
1941
+ else :
1942
+ exit_number = int ("0" + exit_name [4 :])
1943
+
1944
+ expr = loop .while_loop ()
1945
+ op = _expr .TupleGetItem (expr , exit_number )
1946
+ elif node .op == "Enter" :
1947
+ op = self ._nodes [node .input [0 ]]
1948
+ elif node .op == "LoopCond" :
1949
+ op = self ._nodes [node .input [0 ]]
1950
+ assert len (op ) == 1
1951
+ self ._loops [node_name_prefix ].cond = op [0 ]
1952
+ elif node .op == "Switch" :
1953
+ op = self ._nodes [node .input [0 ]]
1954
+ assert len (op ) == 1
1955
+ if _in_while_loop (control_flow_node_map , node_name_prefix ):
1956
+ self ._loops [node_name_prefix ].loop_vars .append (op [0 ])
1957
+ else :
1958
+ if node_name_prefix not in self ._branches :
1959
+ self ._branches [node_name_prefix ] = Branch ()
1960
+ self ._branches [node_name_prefix ].cond = ir_pass .infer_type (op [0 ])
1961
+ elif node .op == "NextIteration" :
1962
+ op = self ._nodes [node .input [0 ]]
1963
+ assert len (op ) == 1
1964
+ self ._loops [node_name_prefix ].body .append (op [0 ])
1965
+ else :
1966
+ raise Exception ("Cannot identify control flow operator: " +
1967
+ "{}" .format (node .op ))
1968
+
1969
+ return op
1970
+
1971
+
1659
1972
def _convert_operator (self , op_name , inputs , attrs ,
1660
1973
graph , identity_list = None , convert_map = None ):
1661
1974
"""Convert from Tensorflow operator to relay operator.
0 commit comments