@@ -114,13 +114,12 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
114114 # initialized: A list of names of the initialized tensors.
115115
116116 if graph_def .initializer :
117- input_dict_items = cls ._onnx_initializer_to_input_dict_items (
118- graph_def .initializer )
119117 initialized = {init .name for init in graph_def .initializer }
120118 else :
121- input_dict_items = []
122119 initialized = set ()
123120
121+ input_dict = dict ()
122+
124123 module = BackendTFModule (handlers , opset , strict , graph_def , cls )
125124 signatures = dict ()
126125 for value_info in graph_def .input :
@@ -146,7 +145,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
146145 shape = shape
147146 ) if value_info .name not in input_tensor_dict else input_tensor_dict [
148147 value_info .name ]
149- input_dict_items . append (( value_info_name , x ))
148+ input_dict [ value_info . name ] = x
150149
151150 tf_rep = TensorflowRep ()
152151 tf_rep .inputs = [
@@ -159,8 +158,7 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict, **kwargs):
159158 tf_rep .tf_module = module
160159 tf_rep .signatures = signatures
161160 tf_rep .tensor_dict = module .gen_tensor_dict (
162- input_dict_items ) if gen_tensor_dict else None
163-
161+ input_dict ) if gen_tensor_dict else None
164162 return tf_rep
165163
166164 @classmethod
@@ -288,55 +286,30 @@ def supports_device(cls, device):
288286 @classmethod
289287 def onnx_graph_to_tensorflow_ops (cls ,
290288 subgraph ,
291- input_values ,
292289 tensor_dict ,
293290 opset = None ,
294291 strict = True ):
295292 """
296293 Converts ONNX graph to Tensorflow operations
297294 Args:
298- subgraph: the ONNX graph to be converted
299- input_values: dictionary with values/tensors to initialize
300- the subgraph inputs. if the subgraph.input
301- are send in as parameters then it is required,
302- otherwise this can be empty dictionary
303- tensor_dict: the dictionary that contain values for all the
304- node.inputs in the subgraph that are not defined
305- in the subgraph or input_values.
295+ subgraph: the ONNX graph to be converted.
296+ tensor_dict: tensor dict of the subgraph.
306297 opset: opset version of the operator set.
307298 strict: whether to enforce semantic equivalence between the
308299 original model and the converted tensorflow model,
309300 defaults to True (yes, enforce semantic equivalence).
310301 Returns:
311302 array of Tensorflow Tensors
312303 """
313- # get the subgraph.input from input_values
314- subgraph_tensor_dict = input_values .copy ()
315- # get the rest of the subgraph input from tensor_dict
316- for i in subgraph .input :
317- if i .name not in subgraph_tensor_dict .keys ():
318- subgraph_tensor_dict [i .name ] = tensor_dict [i .name ]
319- # get the required initializer constant node(s) for the subgraph
320- # Need to get the initializer constant nodes from tensor_dict here
321- # because input from initializer will not be send in as inputs
322- # to the subgraph and those nodes are not in the subgraph
323- nodes_outputs = []
324- for node in subgraph .node :
325- for o_name in node .output :
326- nodes_outputs .append (o_name )
327304 for node in subgraph .node :
328- for i_name in node .input :
329- if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict .keys (
330- ):
331- subgraph_tensor_dict [i_name ] = tensor_dict [i_name ]
332305 onnx_node = OnnxNode (node )
333306 output_ops = cls ._onnx_node_to_tensorflow_op (onnx_node ,
334- subgraph_tensor_dict ,
307+ tensor_dict ,
335308 opset = opset ,
336309 strict = strict )
337310 curr_node_output_map = dict (zip (onnx_node .outputs , output_ops ))
338- subgraph_tensor_dict .update (curr_node_output_map )
339- return subgraph_tensor_dict
311+ tensor_dict .update (curr_node_output_map )
312+ return tensor_dict
340313
341314 @classmethod
342315 def onnx_graph_to_tensorflow_rep (cls , graph_def , strict = True , ** kwargs ):
0 commit comments