@@ -64,12 +64,12 @@ def prepare(cls,
6464 super (TensorflowBackend , cls ).prepare (model , device , ** kwargs )
6565 common .logger .setLevel (logging_level )
6666 common .logger .handlers [0 ].setLevel (logging_level )
67- common .sys_config .auto_cast = auto_cast
67+ common .sys_config .auto_cast = auto_cast
6868
69- return cls .onnx_model_to_tensorflow_rep (model , strict )
69+ return cls .onnx_model_to_tensorflow_rep (model , strict , ** kwargs )
7070
7171 @classmethod
72- def onnx_model_to_tensorflow_rep (cls , model , strict ):
72+ def onnx_model_to_tensorflow_rep (cls , model , strict , ** kwargs ):
7373 """ Convert ONNX model to TensorflowRep.
7474
7575 :param model: ONNX ModelProto object.
@@ -86,45 +86,68 @@ def onnx_model_to_tensorflow_rep(cls, model, strict):
8686 opset_import = [make_opsetid (defs .ONNX_DOMAIN , 1 )]
8787 else :
8888 opset_import = model .opset_import
89- return cls ._onnx_graph_to_tensorflow_rep (model .graph , opset_import , strict )
89+ return cls ._onnx_graph_to_tensorflow_rep (model .graph , opset_import , strict ,
90+ ** kwargs )
9091
9192 @classmethod
92- def _onnx_graph_to_tensorflow_rep (cls , graph_def , opset , strict ):
93+ def _onnx_graph_to_tensorflow_rep (cls , graph_def , opset , strict , ** kwargs ):
9394 """ Convert ONNX graph to TensorflowRep.
9495
9596 :param graph_def: ONNX GraphProto object.
9697 :param opset: ONNX OperatorSetIdProto list.
9798 :param strict: whether to enforce semantic equivalence between the original model
9899 and the converted tensorflow model.
100+ :kwargs: additional arguements to generate tensor_dict for model debugging
99101 :return: TensorflowRep object.
100102 """
103+ # To generate tensor_dict or not, default is False
104+ gen_tensor_dict = kwargs [
105+ 'gen_tensor_dict' ] if 'gen_tensor_dict' in kwargs else False
106+ # User provided input tensors, in the case the model inputs have unknown shapes
107+ input_tensor_dict = kwargs [
108+ 'input_tensor_dict' ] if 'input_tensor_dict' in kwargs else dict ()
109+
101110 handlers = cls ._get_handlers (opset )
102111
103112 # initializer: TensorProtos representing the values to initialize
104113 # a given tensor.
105114 # initialized: A list of names of the initialized tensors.
106115
107116 if graph_def .initializer :
117+ input_dict_items = cls ._onnx_initializer_to_input_dict_items (
118+ graph_def .initializer )
108119 initialized = {init .name for init in graph_def .initializer }
109120 else :
121+ input_dict_items = []
110122 initialized = set ()
111123
112124 module = BackendTFModule (handlers , opset , strict , graph_def , cls )
113125 signatures = dict ()
114-
115126 for value_info in graph_def .input :
116127 if value_info .name in initialized :
117128 continue
118129 shape = list (
119130 d .dim_value if (d .dim_value > 0 and d .dim_param == "" ) else None
120131 for d in value_info .type .tensor_type .shape .dim )
121132 value_info_name = value_info .name .replace (
122- ":" , "_tf_" ) + "_" + get_unique_suffix (
123- ) if ":" in value_info .name else value_info .name
133+ ":" , "_tf_" ) + "_" + get_unique_suffix (
134+ ) if ":" in value_info .name else value_info .name
124135
125- tf_spec = tf .TensorSpec (shape , data_type .onnx2tf (value_info .type .tensor_type .elem_type ), value_info_name )
136+ tf_spec = tf .TensorSpec (
137+ shape , data_type .onnx2tf (value_info .type .tensor_type .elem_type ),
138+ value_info_name )
126139 signatures [value_info .name ] = tf_spec
127140
141+ if gen_tensor_dict :
142+ x = tf .constant (
143+ 0 ,
144+ dtype = data_type .onnx2tf (value_info .type .tensor_type .elem_type ),
145+ name = value_info_name ,
146+ shape = shape
147+ ) if value_info .name not in input_tensor_dict else input_tensor_dict [
148+ value_info .name ]
149+ input_dict_items .append ((value_info_name , x ))
150+
128151 tf_rep = TensorflowRep ()
129152 tf_rep .inputs = [
130153 value_info .name
@@ -135,6 +158,9 @@ def _onnx_graph_to_tensorflow_rep(cls, graph_def, opset, strict):
135158 module .outputs = tf_rep .outputs
136159 tf_rep .tf_module = module
137160 tf_rep .signatures = signatures
161+ tf_rep .tensor_dict = module .gen_tensor_dict (
162+ input_dict_items ) if gen_tensor_dict else None
163+
138164 return tf_rep
139165
140166 @classmethod
@@ -148,7 +174,9 @@ def run_node(cls, node, inputs, device='CPU', outputs_info=None, **kwargs):
148174 :param kwargs: Other args.
149175 :return: Outputs.
150176 """
177+
151178 class TFModule (tf .Module ):
179+
152180 def __init__ (self , node ):
153181 super (TFModule , self ).__init__ ()
154182 self .node = node
@@ -171,13 +199,16 @@ def __call__(self, **input_dict):
171199 feed_dict_raw = dict (zip (node .inputs , inputs ))
172200
173201 # TODO: is constant the best way for feeding inputs?
174- input_dict = dict (
175- [( x [ 0 ], tf . constant ( x [ 1 ])) for x in feed_dict_raw . items () ])
202+ input_dict = dict ([( x [ 0 ], tf . constant ( x [ 1 ])) for x in feed_dict_raw . items ()
203+ ])
176204
177205 module = TFModule (node )
178206
179207 output_vals = module (** input_dict )
180- output_vals = [val .numpy () if isinstance (val , tf .Tensor ) else val for val in output_vals ]
208+ output_vals = [
209+ val .numpy () if isinstance (val , tf .Tensor ) else val
210+ for val in output_vals
211+ ]
181212
182213 return namedtupledict ('Outputs' , node .outputs )(* output_vals )
183214
@@ -231,11 +262,13 @@ def _onnx_node_to_tensorflow_op(cls,
231262 """
232263 handlers = handlers or cls ._get_handlers (opset )
233264 if handlers :
234- handler = handlers [node .domain ].get (node .op_type , None ) if node .domain in handlers else None
265+ handler = handlers [node .domain ].get (
266+ node .op_type , None ) if node .domain in handlers else None
235267 if handler :
236268 return handler .handle (node , tensor_dict = tensor_dict , strict = strict )
237269
238- raise BackendIsNotSupposedToImplementIt ("{} is not implemented." .format (node .op_type ))
270+ raise BackendIsNotSupposedToImplementIt ("{} is not implemented." .format (
271+ node .op_type ))
239272
240273 @classmethod
241274 def _get_handlers (cls , opset ):
@@ -293,7 +326,8 @@ def onnx_graph_to_tensorflow_ops(cls,
293326 nodes_outputs .append (o_name )
294327 for node in subgraph .node :
295328 for i_name in node .input :
296- if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict .keys ():
329+ if i_name not in nodes_outputs and i_name not in subgraph_tensor_dict .keys (
330+ ):
297331 subgraph_tensor_dict [i_name ] = tensor_dict [i_name ]
298332 onnx_node = OnnxNode (node )
299333 output_ops = cls ._onnx_node_to_tensorflow_op (onnx_node ,
@@ -305,7 +339,7 @@ def onnx_graph_to_tensorflow_ops(cls,
305339 return subgraph_tensor_dict
306340
307341 @classmethod
308- def onnx_graph_to_tensorflow_rep (cls , graph_def , strict = True ):
342+ def onnx_graph_to_tensorflow_rep (cls , graph_def , strict = True , ** kwargs ):
309343 """
310344 Converts ONNX graph to TensorflowRep
311345 Args:
@@ -318,7 +352,7 @@ def onnx_graph_to_tensorflow_rep(cls, graph_def, strict=True):
318352 """
319353 # get the opset of the installed ONNX
320354 opset = [make_opsetid (defs .ONNX_DOMAIN , defs .onnx_opset_version ())]
321- return cls ._onnx_graph_to_tensorflow_rep (graph_def , opset , strict )
355+ return cls ._onnx_graph_to_tensorflow_rep (graph_def , opset , strict , ** kwargs )
322356
323357
324358prepare = TensorflowBackend .prepare
0 commit comments