Skip to content

Commit 24877e8

Browse files
authored
Merge branch 'master' into fix_instance_norm
2 parents cea3239 + 58ace0a commit 24877e8

File tree

19 files changed

+848
-497
lines changed

19 files changed

+848
-497
lines changed

doc/support_status.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ Notes:
3535
|Celu|-|-|-|-|-|-|-|-|-|-|-|**12**|12|Celu|
3636
|Clip|**1**|1|1|1|1|**6**|6|6|6|6|**11**|**12**|**13**|Clip|
3737
|Compress|-|-|-|-|-|-|-|-|**9**|9|**11**|11|11|Compress|
38-
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**:small_red_triangle:|Concat|
38+
|Concat|**1**|1|1|**4**|4|4|4|4|4|4|**11**|11|**13**|Concat|
3939
|ConcatFromSequence|-|-|-|-|-|-|-|-|-|-|**11**:small_orange_diamond:|11:small_orange_diamond:|11:small_orange_diamond:|ConcatFromSequence|
40-
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**:small_red_triangle:|Constant|
40+
|Constant|**1**|1|1|1|1|1|1|1|**9**|9|**11**|**12**|**13**|Constant|
4141
|ConstantOfShape|-|-|-|-|-|-|-|-|**9**|9|9|9|9|ConstantOfShape|
4242
|Conv|**1**|1|1|1|1|1|1|1|1|1|**11**|11|11|Conv|
4343
|ConvInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|ConvInteger|
@@ -68,26 +68,26 @@ Notes:
6868
|GlobalAveragePool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalAveragePool|
6969
|GlobalLpPool|**1**|**2**|2|2|2|2|2|2|2|2|2|2|2|GlobalLpPool|
7070
|GlobalMaxPool|**1**|1|1|1|1|1|1|1|1|1|1|1|1|GlobalMaxPool|
71-
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Greater|
71+
|Greater|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Greater|
7272
|GreaterOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|GreaterOrEqual|
7373
|HardSigmoid|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|HardSigmoid|
7474
|Hardmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Hardmax|
7575
|Identity|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**:small_red_triangle:|Identity|
76-
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|If|
76+
|If|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|If|
7777
|InstanceNormalization|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|InstanceNormalization|
7878
|IsInf|-|-|-|-|-|-|-|-|-|**10**|10|10|10|IsInf|
7979
|IsNaN|-|-|-|-|-|-|-|-|**9**|9|9|9|**13**:small_red_triangle:|IsNaN|
8080
|LRN|**1**|1|1|1|1|1|1|1|1|1|1|1|**13**|LRN|
8181
|LSTM|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**7**:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|7:small_orange_diamond:|LSTM|
8282
|LeakyRelu|**1**|1|1|1|1|**6**|6|6|6|6|6|6|6|LeakyRelu|
83-
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**:small_red_triangle:|Less|
83+
|Less|**1**|1|1|1|1|1|**7**|7|**9**|9|9|9|**13**|Less|
8484
|LessOrEqual|-|-|-|-|-|-|-|-|-|-|-|**12**|12|LessOrEqual|
8585
|Log|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Log|
8686
|LogSoftmax|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|LogSoftmax|
87-
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Loop|
87+
|Loop|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**|Loop|
8888
|LpNormalization|**1**|1|1|1|1|1|1|1|1|1|1|1|1|LpNormalization|
8989
|LpPool|**1**|**2**|2|2|2|2|2|2|2|2|**11**|11|11|LpPool|
90-
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**:small_red_triangle:|MatMul|
90+
|MatMul|**1**|1|1|1|1|1|1|1|**9**|9|9|9|**13**|MatMul|
9191
|MatMulInteger|-|-|-|-|-|-|-|-|-|**10**|10|10|10|MatMulInteger|
9292
|Max|**1**|1|1|1|1|**6**|6|**8**|8|8|8|**12**|**13**|Max|
9393
|MaxPool|**1**:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|1:small_orange_diamond:|**8**:small_orange_diamond:|8:small_orange_diamond:|**10**:small_orange_diamond:|**11**:small_orange_diamond:|**12**:small_orange_diamond:|12:small_orange_diamond:|MaxPool|
@@ -164,7 +164,7 @@ Notes:
164164
|Sqrt|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Sqrt|
165165
|Squeeze|**1**|1|1|1|1|1|1|1|1|1|**11**|11|**13**:small_red_triangle:|Squeeze|
166166
|StringNormalizer|-|-|-|-|-|-|-|-|-|**10**:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|10:small_red_triangle:|StringNormalizer|
167-
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**:small_red_triangle:|Sub|
167+
|Sub|**1**|1|1|1|1|**6**|**7**|7|7|7|7|7|**13**|Sub|
168168
|Sum|**1**|1|1|1|1|**6**|6|**8**|8|8|8|8|**13**:small_red_triangle:|Sum|
169169
|Tan|-|-|-|-|-|-|**7**|7|7|7|7|7|7|Tan|
170170
|Tanh|**1**|1|1|1|1|**6**|6|6|6|6|6|6|**13**|Tanh|
@@ -179,7 +179,7 @@ Notes:
179179
|Where|-|-|-|-|-|-|-|-|**9**|9|9|9|9|Where|
180180
|Xor|**1**|1|1|1|1|1|**7**|7|7|7|7|7|7|Xor|
181181

182-
ONNX-TF Supported Operators / ONNX Operators: 103 / 162
182+
ONNX-TF Supported Operators / ONNX Operators: 105 / 162
183183

184184
Notes:
185185
1. Cast: Cast string to data types other than float32/float64/int32/int64 is not supported in Tensorflow

onnx_tf/backend.py

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,12 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
114114
# initialized: A list of names of the initialized tensors.
115115

116116
if graph_def.initializer:
117-
input_dict_items = cls._onnx_initializer_to_input_dict_items(
118-
graph_def.initializer)
119117
initialized = {init.name for init in graph_def.initializer}
120118
else:
121-
input_dict_items = []
122119
initialized = set()
123120

121+
input_dict = dict()
122+
124123
module = BackendTFModule(handlers, opset, strict, graph_def, cls)
125124
signatures = dict()
126125
for value_info in graph_def.input:
@@ -146,7 +145,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
146145
shape=shape
147146
) if value_info.name not in input_tensor_dict else input_tensor_dict[
148147
value_info.name]
149-
input_dict_items.append((value_info_name, x))
148+
input_dict[value_info.name] = x
150149

151150
tf_rep = TensorflowRep()
152151
tf_rep.inputs = [
@@ -159,8 +158,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
159158
tf_rep.tf_module = module
160159
tf_rep.signatures = signatures
161160
tf_rep.tensor_dict = module.gen_tensor_dict(
162-
input_dict_items) if gen_tensor_dict else None
163-
161+
input_dict) if gen_tensor_dict else None
164162
return tf_rep
165163

166164
@classmethod
@@ -288,55 +286,30 @@ def supports_device(cls, device):
288286
@classmethod
289287
def onnx_graph_to_tensorflow_ops(cls,
290288
subgraph,
291-
input_values,
292289
tensor_dict,
293290
opset=None,
294291
strict=True):
295292
"""
296293
Converts ONNX graph to Tensorflow operations
297294
Args:
298-
subgraph: the ONNX graph to be converted
299-
input_values: dictionary with values/tensors to initialize
300-
the subgraph inputs. if the subgraph.input
301-
are send in as parameters then it is required,
302-
otherwise this can be empty dictionary
303-
tensor_dict: the dictionary that contain values for all the
304-
node.inputs in the subgraph that are not defined
305-
in the subgraph or input_values.
295+
subgraph: the ONNX graph to be converted.
296+
tensor_dict: tensor dict of the subgraph.
306297
opset: opset version of the operator set.
307298
strict: whether to enforce semantic equivalence between the
308299
original model and the converted tensorflow model,
309300
defaults to True (yes, enforce semantic equivalence).
310301
Returns:
311302
array of Tensorflow Tensors
312303
"""
313-
# get the subgraph.input from input_values
314-
subgraph_tensor_dict = input_values.copy()
315-
# get the rest of the subgraph input from tensor_dict
316-
for i in subgraph.input:
317-
if i.name not in subgraph_tensor_dict.keys():
318-
subgraph_tensor_dict[i.name] = tensor_dict[i.name]
319-
# get the required initializer constant node(s) for the subgraph
320-
# Need to get the initializer constant nodes from tensor_dict here
321-
# because input from initializer will not be send in as inputs
322-
# to the subgraph and those nodes are not in the subgraph
323-
nodes_outputs = []
324-
for node in subgraph.node:
325-
for o_name in node.output:
326-
nodes_outputs.append(o_name)
327304
for node in subgraph.node:
328-
for i_name in node.input:
329-
if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict.keys(
330-
):
331-
subgraph_tensor_dict[i_name] = tensor_dict[i_name]
332305
onnx_node = OnnxNode(node)
333306
output_ops = cls._onnx_node_to_tensorflow_op(onnx_node,
334-
subgraph_tensor_dict,
307+
tensor_dict,
335308
opset=opset,
336309
strict=strict)
337310
curr_node_output_map = dict(zip(onnx_node.outputs, output_ops))
338-
subgraph_tensor_dict.update(curr_node_output_map)
339-
return subgraph_tensor_dict
311+
tensor_dict.update(curr_node_output_map)
312+
return tensor_dict
340313

341314
@classmethod
342315
def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True, **kwargs):

onnx_tf/backend_tf_module.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,32 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1313
self.backend = backend
1414
self.outputs = []
1515

16+
# get initializer from the main graph and all subgraphs in loop or if or scan
17+
# into tensor_dict
18+
def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
19+
if graph.initializer:
20+
graph_tensor_dict.update(
21+
self.backend._onnx_initializer_to_input_dict_items(graph.initializer))
22+
for node in graph.node:
23+
if node.op_type in ['Loop', 'Scan']:
24+
onnx_node = OnnxNode(node)
25+
body = onnx_node.attrs["body"]
26+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
27+
body, graph_tensor_dict)
28+
elif node.op_type == 'If':
29+
onnx_node = OnnxNode(node)
30+
then_branch = onnx_node.attrs['then_branch']
31+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
32+
then_branch, graph_tensor_dict)
33+
else_branch = onnx_node.attrs['else_branch']
34+
graph_tensor_dict = self._get_initializer_from_graph_and_subgraphs(
35+
else_branch, graph_tensor_dict)
36+
return graph_tensor_dict
37+
1638
@tf.function
17-
def gen_tensor_dict(self, input_dict_items):
18-
tensor_dict = dict(input_dict_items)
39+
def gen_tensor_dict(self, input_dict):
40+
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
41+
self.graph_def, dict(input_dict))
1942

2043
for node in self.graph_def.node:
2144
onnx_node = OnnxNode(node)
@@ -31,15 +54,8 @@ def gen_tensor_dict(self, input_dict_items):
3154

3255
@tf.function
3356
def __call__(self, **kwargs):
34-
tensor_dict = kwargs
35-
36-
if self.graph_def.initializer:
37-
input_dict_items = self.backend._onnx_initializer_to_input_dict_items(
38-
self.graph_def.initializer)
39-
else:
40-
input_dict_items = []
41-
42-
tensor_dict.update(input_dict_items)
57+
tensor_dict = self._get_initializer_from_graph_and_subgraphs(
58+
self.graph_def, kwargs)
4359

4460
for node in self.graph_def.node:
4561
onnx_node = OnnxNode(node)

onnx_tf/handlers/backend/concat.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ def version_4(cls, node, **kwargs):
2525
@classmethod
2626
def version_11(cls, node, **kwargs):
2727
return cls._common(node, **kwargs)
28+
29+
@classmethod
30+
def version_13(cls, node, **kwargs):
31+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/constant.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ def version_12(cls, node, **kwargs):
7171
inputs=[value],
7272
attrs={"dtype": dtype})
7373
]
74+
75+
@classmethod
76+
def version_13(cls, node, **kwargs):
77+
return cls.version_12(node, **kwargs)

onnx_tf/handlers/backend/greater.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
2121
@classmethod
2222
def version_9(cls, node, **kwargs):
2323
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
24+
25+
@classmethod
26+
def version_13(cls, node, **kwargs):
27+
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

onnx_tf/handlers/backend/if.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,19 @@ def _common(cls, node, **kwargs):
2121
def true_fn():
2222
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
2323
subgraph=then_branch,
24-
input_values={}, # all inputs of then_branch are in tensor_dict
25-
tensor_dict=kwargs["tensor_dict"],
24+
tensor_dict=dict(kwargs["tensor_dict"]),
2625
opset=current_opset)
2726
return [subgraph_tensor_dict[o.name] for o in then_branch.output]
2827

2928
def false_fn():
3029
subgraph_tensor_dict = onnx_tf.backend.onnx_graph_to_tensorflow_ops(
3130
subgraph=else_branch,
32-
input_values={}, # all inputs of else_branch are in tensor_dict
33-
tensor_dict=kwargs["tensor_dict"],
31+
tensor_dict=dict(kwargs["tensor_dict"]),
3432
opset=current_opset)
3533
return [subgraph_tensor_dict[o.name] for o in else_branch.output]
3634

37-
return [
38-
cls.make_tensor_from_onnx_node(node, inputs=[cond, true_fn, false_fn])
39-
]
35+
return cls.make_tensor_from_onnx_node(node,
36+
inputs=[cond, true_fn, false_fn])
4037

4138
@classmethod
4239
def version_1(cls, node, **kwargs):
@@ -45,3 +42,7 @@ def version_1(cls, node, **kwargs):
4542
@classmethod
4643
def version_11(cls, node, **kwargs):
4744
return cls._common(node, **kwargs)
45+
46+
@classmethod
47+
def version_13(cls, node, **kwargs):
48+
return cls._common(node, **kwargs)

onnx_tf/handlers/backend/less.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ def version_7(cls, node, **kwargs):
2121
@classmethod
2222
def version_9(cls, node, **kwargs):
2323
return [cls.make_tensor_from_onnx_node(node, **kwargs)]
24+
25+
@classmethod
26+
def version_13(cls, node, **kwargs):
27+
return [cls.make_tensor_from_onnx_node(node, **kwargs)]

0 commit comments

Comments
 (0)