Skip to content

Commit 88a679a

Browse files
authored
Merge pull request #87 from iksnagreb/fix/remove_identity_ops
Fix RemoveIdentityOps not correctly handling ops following fork-nodes
2 parents e02f701 + 2d09341 commit 88a679a

File tree

3 files changed

+67
-24
lines changed

3 files changed

+67
-24
lines changed

src/qonnx/core/modelwrapper.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -429,14 +429,24 @@ def is_fork_node(self, node):
429429
"""Checks if the given node is a fork, that is, the node has multiple
430430
direct successors"""
431431
direct_successors = self.find_direct_successors(node)
432-
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
432+
# if the node output is also wired to a top-level output, it is still
433+
# a fork with only 1 direct successor
434+
if node.output[0] in [x.name for x in self.graph.output]:
435+
is_fork = False if direct_successors is None else (len(direct_successors) > 0)
436+
else:
437+
is_fork = False if direct_successors is None else (len(direct_successors) > 1)
433438
return is_fork
434439

435440
def is_join_node(self, node):
436441
"""Checks if the given node is a join, that is, the node has multiple
437442
direct predecessors"""
438443
direct_predecessors = self.find_direct_predecessors(node)
439-
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
444+
# if the node input is also wired to a top-level input, it is still
445+
# a fork with only 1 direct predecessor
446+
if node.input[0] in [x.name for x in self.graph.input]:
447+
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 0)
448+
else:
449+
is_join = False if direct_predecessors is None else (len(direct_predecessors) > 1)
440450
return is_join
441451

442452
def get_all_tensor_names(self):

src/qonnx/transformation/remove.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28-
29-
3028
import numpy as np
29+
import warnings
3130

3231
from qonnx.core.modelwrapper import ModelWrapper
3332
from qonnx.transformation.base import Transformation
@@ -58,21 +57,43 @@ def apply(self, model: ModelWrapper):
5857

5958

6059
def remove_node_and_rewire(model, node):
60+
# Currently cannot remove and rewire join-nodes, probably not necessary to
61+
# support this
62+
if model.is_join_node(node):
63+
# Log this as a warning, so the user is aware of this, there might be
64+
# somthing wrong or some checks missing at the caller site
65+
warnings.warn("Removing join-node operation is currently not supported")
66+
# Exit the function here without doing anything
67+
return
68+
# We already know that node is not a join-node, thus to rewire, we only need
69+
# to check the single producer
6170
producer = model.find_producer(node.input[0])
62-
if producer is not None:
63-
# wire output tensor to
64-
# output of producer node
71+
# If there is a producer which is not a fork-node, rewiring is simple
72+
if producer is not None and not model.is_fork_node(producer):
73+
# Rewire by skipping the node, letting the producer directly feed the
74+
# nodes output.
75+
# TODO: Check whether this already covers fork-node identities?
6576
producer.output[0] = node.output[0]
77+
# If there is no producer or the producer forks, rewiring is a bit more
78+
# complicated
6679
else:
67-
# node is first in graph
80+
# Now it depends on the successor nodes to rewire their inputs
6881
successors = model.find_direct_successors(node)
82+
# Singular node detached from the rest of the graph?
6983
assert successors is not None, "Whole graph is one node."
70-
for succ in successors:
71-
for i, s_inp in enumerate(succ.input):
84+
# We need to rewire the input of each successor to not detach parts of
85+
# the graph
86+
for successor in successors:
87+
# Find the inputs of the successor which are produced by the node to
88+
# be removed
89+
for i, s_inp in enumerate(successor.input):
90+
# Note: This might happen multiple times?
7291
if s_inp == node.output[0]:
73-
# rewire successor's input directly to graph input
74-
succ.input[i] = node.input[0]
75-
# remove node
92+
# Rewire successor's input directly to nodes input
93+
# Note: Node may not be a join-node, but there is probably
94+
# no such thing as join-node identity anyway
95+
successor.input[i] = node.input[0]
96+
# Remove node
7697
model.graph.node.remove(node)
7798

7899

@@ -117,5 +138,9 @@ def apply(self, model):
117138
remove_node_and_rewire(model, n)
118139
graph_modified = True
119140
break
141+
elif n.op_type == "Identity":
142+
remove_node_and_rewire(model, n)
143+
graph_modified = True
144+
break
120145
model = model.transform(InferShapes())
121146
return (model, graph_modified)

tests/transformation/test_remove_identity_ops.py

+19-11
Original file line numberDiff line numberDiff line change
@@ -51,28 +51,34 @@ def insert_identity_op(model, op, as_first_node, approx):
5151
val = np.asarray([zero_val], dtype=np.float32)
5252
elif op in ["Mul", "Div"]:
5353
val = np.asarray([one_val], dtype=np.float32)
54+
elif op in ["Identity"]:
55+
val = None
5456
else:
5557
return
5658

5759
graph = model.graph
60+
if val is None:
61+
inplist = ["inp" if as_first_node else "div_out"]
62+
else:
63+
model.set_initializer("value", val)
64+
inplist = ["inp" if as_first_node else "div_out", "value"]
65+
identity_node = helper.make_node(op, inplist, ["ident_out"])
5866
if as_first_node:
59-
identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
6067
graph.node.insert(0, identity_node)
6168
graph.node[1].input[0] = "ident_out"
6269
else:
63-
identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
6470
graph.node.insert(3, identity_node)
6571
graph.node[-1].input[0] = "ident_out"
66-
model.set_initializer("value", val)
6772

6873
return model
6974

7075

7176
# identity operations to be inserted
72-
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
77+
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
7378
@pytest.mark.parametrize("approx", [False, True])
7479
@pytest.mark.parametrize("as_first_node", [False, True])
75-
def test_remove_identity_ops(op, as_first_node, approx):
80+
@pytest.mark.parametrize("fork_before_id", [False, True])
81+
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
7682
# set up onnx model
7783
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
7884
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
@@ -109,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
109115
model = model.transform(InferShapes())
110116
model = model.transform(InferDataTypes())
111117
idict = {"inp": inp_values}
112-
odict = oxe.execute_onnx(model, idict)
113-
out_before = odict["outp"]
118+
odict_before = oxe.execute_onnx(model, idict)
114119
num_of_nodes_before = len(model.graph.node)
115-
120+
if fork_before_id and not as_first_node:
121+
divout_vi = model.get_tensor_valueinfo("div_out")
122+
model.graph.output.append(divout_vi)
123+
model.graph.value_info.remove(divout_vi)
116124
model = model.transform(RemoveIdentityOps())
117125
num_of_nodes_after = len(model.graph.node)
118126
assert num_of_nodes_before - 1 == num_of_nodes_after
119127

120-
odict = oxe.execute_onnx(model, idict)
121-
out_after = odict["outp"]
122-
assert np.isclose(out_before, out_after, atol=1e-3).all()
128+
odict_after = oxe.execute_onnx(model, idict)
129+
outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()]
130+
assert all(outputs_same)

0 commit comments

Comments
 (0)