Skip to content

Commit 58ace0a

Browse files
Fix Loop and If handlers (#753)
1. Fix Loop and If handlers to run the new onnx_backend_test 2. Add initializers from all the subgraphs into tensor_dict 3. Update If, Loop and Scan to make a copy of tensor_dict and add in all the inputs and outputs of the subgraph into this tensor_dict copy and then send it to backend.onnx_graph_to_tensorflow_ops 4. Update backend.onnx_graph_to_tensorflow_ops to eliminate tensor_dict and just use the subgraph_tensor_dict from IF, Loop and Scan 5. Fix Min and Max to support inputs that are in different shapes 6. Add dynamic input support to SequenceInsert 7. Add Opset 13 support for If, Loop, Concat, Constant, Matmul, Sub 8. Fix issue #742 Signed-off-by: Winnie Tsang <[email protected]> Co-authored-by: Chin Huang <[email protected]>
1 parent af4e07d commit 58ace0a

File tree

19 files changed

+848
-497
lines changed

19 files changed

+848
-497
lines changed

doc/support_status.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ Notes:
3535
|Celu|-|-|-|-|-|-|-|-|-|-|-|**12**|12|Celu|
3636
|Clip|**1**|1|1|1|1|**6**|6|6|6|6|**11**|**12**|**13**|Clip|
3737
|Compress|-|-|-|-|-|-|-|-|**9**|9|**11**|11|11|Compress|
38-
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**:small_red_triangle:|Concat|
38+
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**|Concat|
3939
|ConcatFromSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_orange_diamond:|11:small_orange_diamond:|11:small_orange_diamond:|ConcatFromSequence|
40-
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**:small_red_triangle:|Constant|
40+
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**|Constant|
4141
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|9|9|9|ConstantOfShape|
4242
|Conv|**1**|1|1|1|1|1|1|1|1|1|**11**|11|11|Conv|
4343
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|ConvInteger|
@@ -68,26 +68,26 @@ Notes:
6868
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalAveragePool|
6969
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|2|2|2|GlobalLpPool|
7070
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalMaxPool|
71-
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Greater|
71+
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Greater|
7272
|GreaterOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|GreaterOrEqual|
7373
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|HardSigmoid|
7474
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Hardmax|
7575
|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**:small_red_triangle:|Identity|
76-
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|If|
76+
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|If|
7777
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|InstanceNormalization|
7878
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|10|IsInf|
7979
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**:small_red_triangle:|IsNaN|
8080
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|LRN|
8181
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM|
8282
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|LeakyRelu|
83-
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Less|
83+
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Less|
8484
|LessOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|LessOrEqual|
8585
|Log|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Log|
8686
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|LogSoftmax|
87-
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Loop|
87+
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|Loop|
8888
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|1|1|1|LpNormalization|
8989
|LpPool|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|11|LpPool|
90-
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**:small_red_triangle:|MatMul|
90+
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**|MatMul|
9191
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|MatMulInteger|
9292
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|**12**|**13**|Max|
9393
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|12:small_orange_diamond:|MaxPool|
@@ -164,7 +164,7 @@ Notes:
164164
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Sqrt|
165165
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Squeeze|
166166
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|StringNormalizer|
167-
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**:small_red_triangle:|Sub|
167+
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**|Sub|
168168
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|8|8|**13**:small_red_triangle:|Sum|
169169
|Tan|-|-|-|-|-|-|**7**|7|7|7|7|7|7|Tan|
170170
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Tanh|
@@ -179,7 +179,7 @@ Notes:
179179
|Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Where|
180180
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|Xor|
181181

182-
ONNX-TF Supported Operators / ONNX Operators: 103 / 162
182+
ONNX-TF Supported Operators / ONNX Operators: 105 / 162
183183

184184
Notes:
185185
1. Cast: Cast string to data types other than float32/float64/int32/int64 is not supported in Tensorflow

onnx_tf/backend.py

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,12 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
114114
# initialized: A list of names of the initialized tensors.
115115

116116
if graph_def.initializer:
117-
input_dict_items = cls._onnx_initializer_to_input_dict_items(
118-
graph_def.initializer)
119117
initialized = {init.name for init in graph_def.initializer}
120118
else:
121-
input_dict_items = []
122119
initialized = set()
123120

121+
input_dict = dict()
122+
124123
module = BackendTFModule(handlers, opset, strict, graph_def, cls)
125124
signatures = dict()
126125
for value_info in graph_def.input:
@@ -146,7 +145,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
146145
shape=shape
147146
) if value_info.name not in input_tensor_dict else input_tensor_dict[
148147
value_info.name]
149-
input_dict_items.append((value_info_name, x))
148+
input_dict[value_info.name] = x
150149

151150
tf_rep = TensorflowRep()
152151
tf_rep.inputs = [
@@ -159,8 +158,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
159158
tf_rep.tf_module = module
160159
tf_rep.signatures = signatures
161160
tf_rep.tensor_dict = module.gen_tensor_dict(
162-
input_dict_items) if gen_tensor_dict else None
163-
161+
input_dict) if gen_tensor_dict else None
164162
return tf_rep
165163

166164
@classmethod
@@ -288,55 +286,30 @@ def supports_device(cls, device):
288286
@classmethod
289287
def onnx_graph_to_tensorflow_ops(cls,
290288
subgraph,
291-
input_values,
292289
tensor_dict,
293290
opset=None,
294291
strict=True):
295292
"""
296293
Converts ONNX graph to Tensorflow operations
297294
Args:
298-
subgraph: the ONNX graph to be converted
299-
input_values: dictionary with values/tensors to initialize
300-
the subgraph inputs. if the subgraph.input
301-
are send in as parameters then it is required,
302-
otherwise this can be empty dictionary
303-
tensor_dict: the dictionary that contain values for all the
304-
node.inputs in the subgraph that are not defined
305-
in the subgraph or input_values.
295+
subgraph: the ONNX graph to be converted.
296+
tensor_dict: tensor dict of the subgraph.
306297
opset: opset version of the operator set.
307298
strict: whether to enforce semantic equivalence between the
308299
original model and the converted tensorflow model,
309300
defaults to True (yes, enforce semantic equivalence).
310301
Returns:
311302
array of Tensorflow Tensors
312303
"""
313-
# get the subgraph.input from input_values
314-
subgraph_tensor_dict = input_values.copy()
315-
# get the rest of the subgraph input from tensor_dict
316-
for i in subgraph.input:
317-
if i.name not in subgraph_tensor_dict.keys():
318-
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
319-
# get the required initializer constant node(s) for the subgraph
320-
# Need to get the initializer constant nodes from tensor_dict here
321-
# because input from initializer will not be send in as inputs
322-
# to the subgraph and those nodes are not in the subgraph
323-
nodes_outputs = []
324-
for node in subgraph.node:
325-
for o_name in node.output:
326-
nodes_outputs.append(o_name)
327304
for node in subgraph.node:
328-
for i_name in node.input:
329-
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
330-
):
331-
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
332305
onnx_node = OnnxNode(node)
333306
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
334-
subgraph_tensor_dict,
307+
tensor_dict,
335308
opset=opset,
336309
strict=strict)
337310
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
338-
subgraph_tensor_dict.update(curr_node_output_map)
339-
return subgraph_tensor_dict
311+
tensor_dict.update(curr_node_output_map)
312+
return tensor_dict
340313

341314
@classmethod
342315
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):

onnx_tf/backend_tf_module.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,32 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1313
self.backend = backend
1414
self.outputs = []
1515

16+
# get initializer from the main graph and all subgraphs in loop or if or scan
17+
# into tensor_dict
18+
def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
19+
if graph.initializer:
20+
graph_tensor_dict.update(
21+
self.backend._onnx_initializer_to_input_dict_items(graph.initializer))
22+
for node in graph.node:
23+
if node.op_type in ['Loop', 'Scan']:
24+
onnx_node = OnnxNode(node)
25+
body = onnx_node.attrs["body"]
26+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
27+
body, graph_tensor_dict)
28+
elif node.op_type == 'If':
29+
onnx_node = OnnxNode(node)
30+
then_branch = onnx_node.attrs['then_branch']
31+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
32+
then_branch, graph_tensor_dict)
33+
else_branch = onnx_node.attrs['else_branch']
34+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
35+
else_branch, graph_tensor_dict)
36+
return graph_tensor_dict
37+
1638
@tf.function
17-
def gen_tensor_dict(self, input_dict_items):
18-
tensor_dict = dict(input_dict_items)
39+
def gen_tensor_dict(self, input_dict):
40+
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
41+
self.graph_def, dict(input_dict))
1942

2043
for node in self.graph_def.node:
2144
onnx_node = OnnxNode(node)
@@ -31,15 +54,8 @@ def gen_tensor_dict(self, input_dict_items):
3154

3255
@tf.function
3356
def __call__(self, **kwargs):
34-
tensor_dict = kwargs
35-
36-
if self.graph_def.initializer:
37-
input_dict_items = self.backend._onnx_initializer_to_input_dict_items(
38-
self.graph_def.initializer)
39-
else:
40-
input_dict_items = []
41-
42-
tensor_dict.update(input_dict_items)
57+
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
58+
self.graph_def, kwargs)
4359

4460
for node in self.graph_def.node:
4561
onnx_node = OnnxNode(node)

onnx_tf/handlers/backend/concat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ def version_4(cls, node, **kwargs):
2525
@classmethod
2626
def version_11(cls, node, **kwargs):
2727
return cls._common(node, **kwargs)
28+
29+
@classmethod
30+
def version_13(cls, node, **kwargs):
31+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/constant.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ def version_12(cls, node, **kwargs):
7171
inputs=[value],
7272
attrs={"dtype": dtype})
7373
]
74+
75+
@classmethod
76+
def version_13(cls, node, **kwargs):
77+
return cls.version_12(node, **kwargs)

onnx_tf/handlers/backend/greater.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
2121
@classmethod
2222
def version_9(cls, node, **kwargs):
2323
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
24+
25+
@classmethod
26+
def version_13(cls, node, **kwargs):
27+
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

onnx_tf/handlers/backend/if.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,19 @@ def _common(cls, node, **kwargs):
2121
def true_fn():
2222
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
2323
subgraph=then_branch,
24-
input_values={}, # all inputs of then_branch are in tensor_dict
25-
tensor_dict=kwargs["tensor_dict"],
24+
tensor_dict=dict(kwargs["tensor_dict"]),
2625
opset=current_opset)
2726
return [subgraph_tensor_dict[o.name] for o in then_branch.output]
2827

2928
def false_fn():
3029
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
3130
subgraph=else_branch,
32-
input_values={}, # all inputs of else_branch are in tensor_dict
33-
tensor_dict=kwargs["tensor_dict"],
31+
tensor_dict=dict(kwargs["tensor_dict"]),
3432
opset=current_opset)
3533
return [subgraph_tensor_dict[o.name] for o in else_branch.output]
3634

37-
return [
38-
cls.make_tensor_from_onnx_node(node, inputs=[cond, true_fn, false_fn])
39-
]
35+
return cls.make_tensor_from_onnx_node(node,
36+
inputs=[cond, true_fn, false_fn])
4037

4138
@classmethod
4239
def version_1(cls, node, **kwargs):
@@ -45,3 +42,7 @@ def version_1(cls, node, **kwargs):
4542
@classmethod
4643
def version_11(cls, node, **kwargs):
4744
return cls._common(node, **kwargs)
45+
46+
@classmethod
47+
def version_13(cls, node, **kwargs):
48+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/less.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
2121
@classmethod
2222
def version_9(cls, node, **kwargs):
2323
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
24+
25+
@classmethod
26+
def version_13(cls, node, **kwargs):
27+
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

0 commit comments

Comments
 (0)