@@ -51,28 +51,34 @@ def insert_identity_op(model, op, as_first_node, approx):
51
51
val = np .asarray ([zero_val ], dtype = np .float32 )
52
52
elif op in ["Mul" , "Div" ]:
53
53
val = np .asarray ([one_val ], dtype = np .float32 )
54
+ elif op in ["Identity" ]:
55
+ val = None
54
56
else :
55
57
return
56
58
57
59
graph = model .graph
60
+ if val is None :
61
+ inplist = ["inp" if as_first_node else "div_out" ]
62
+ else :
63
+ model .set_initializer ("value" , val )
64
+ inplist = ["inp" if as_first_node else "div_out" , "value" ]
65
+ identity_node = helper .make_node (op , inplist , ["ident_out" ])
58
66
if as_first_node :
59
- identity_node = helper .make_node (op , ["inp" , "value" ], ["ident_out" ])
60
67
graph .node .insert (0 , identity_node )
61
68
graph .node [1 ].input [0 ] = "ident_out"
62
69
else :
63
- identity_node = helper .make_node (op , ["div_out" , "value" ], ["ident_out" ])
64
70
graph .node .insert (3 , identity_node )
65
71
graph .node [- 1 ].input [0 ] = "ident_out"
66
- model .set_initializer ("value" , val )
67
72
68
73
return model
69
74
70
75
71
76
# identity operations to be inserted
72
- @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" ])
77
+ @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" , "Identity" ])
73
78
@pytest .mark .parametrize ("approx" , [False , True ])
74
79
@pytest .mark .parametrize ("as_first_node" , [False , True ])
75
- def test_remove_identity_ops (op , as_first_node , approx ):
80
+ @pytest .mark .parametrize ("fork_before_id" , [False , True ])
81
+ def test_remove_identity_ops (op , as_first_node , approx , fork_before_id ):
76
82
# set up onnx model
77
83
inp = helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [1 , 4 , 1 , 1 ])
78
84
mul = helper .make_tensor_value_info ("mul" , TensorProto .FLOAT , [])
@@ -109,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
109
115
model = model .transform (InferShapes ())
110
116
model = model .transform (InferDataTypes ())
111
117
idict = {"inp" : inp_values }
112
- odict = oxe .execute_onnx (model , idict )
113
- out_before = odict ["outp" ]
118
+ odict_before = oxe .execute_onnx (model , idict )
114
119
num_of_nodes_before = len (model .graph .node )
115
-
120
+ if fork_before_id and not as_first_node :
121
+ divout_vi = model .get_tensor_valueinfo ("div_out" )
122
+ model .graph .output .append (divout_vi )
123
+ model .graph .value_info .remove (divout_vi )
116
124
model = model .transform (RemoveIdentityOps ())
117
125
num_of_nodes_after = len (model .graph .node )
118
126
assert num_of_nodes_before - 1 == num_of_nodes_after
119
127
120
- odict = oxe .execute_onnx (model , idict )
121
- out_after = odict [ "outp" ]
122
- assert np . isclose ( out_before , out_after , atol = 1e-3 ). all ()
128
+ odict_after = oxe .execute_onnx (model , idict )
129
+ outputs_same = [ np . isclose ( odict_before [ tname ], odict_after [ tname ], atol = 1e-3 ). all () for tname in odict_before . keys () ]
130
+ assert all (outputs_same )
0 commit comments