11import tensorflow as tf
2+ from onnx_tf .common import exception
3+ from onnx_tf .common import get_variable_name
24from onnx_tf .pb_wrapper import OnnxNode
35
46
57class BackendTFModule (tf .Module ):
8+ """ BackendTFModule is the tf.Module class used in backend.prepare,
9+ tf_rep.export_graph and tf_rep.run
10+ """
611
712 def __init__ (self , handlers , opset , strict , graph_def , backend ):
813 super (BackendTFModule , self ).__init__ ()
@@ -14,6 +19,8 @@ def __init__(self, handlers, opset, strict, graph_def, backend):
1419 self .outputs = []
1520 self .initializer_dict = self ._get_initializer_from_graph_and_subgraphs (
1621 self .graph_def , dict ())
22+ self .handler_variables = self ._create_handlers_variables (
23+ self .graph_def , dict ())
1724
1825 # get initializer from the main graph and all subgraphs in loop or if or scan
1926 # into tensor_dict
@@ -37,10 +44,43 @@ def _get_initializer_from_graph_and_subgraphs(self, graph, graph_tensor_dict):
3744 else_branch , graph_tensor_dict )
3845 return graph_tensor_dict
3946
47+ # create tf.Variable for handlers that required to use variable in handler
48+ def _create_handlers_variables (self , graph , vars_dict ):
49+ if self .handlers :
50+ handlers = self .backend ._get_handlers (self .opset )
51+ for node in graph .node :
52+ handler = handlers [node .domain ].get (
53+ node .op_type , None ) if node .domain in handlers else None
54+ if handler and bool (
55+ handler .get_req_vars_template (node , self .initializer_dict )):
56+ for v_name , v_template in handler .get_req_vars_template (
57+ node , self .initializer_dict ).items ():
58+ v_init , v_shape = v_template
59+ v_name = get_variable_name (node , v_name )
60+ if v_name in vars_dict .keys ():
61+ # found duplicated variable name due to non unique node name
62+ exception .NON_UNIQUE_NODE_NAME_EXCEPT ()
63+ vars_dict [v_name ] = tf .Variable (v_init ,
64+ dtype = v_init .dtype ,
65+ shape = v_shape ,
66+ name = v_name )
67+ if node .op_type in ['Loop' , 'Scan' ]:
68+ onnx_node = OnnxNode (node )
69+ body = onnx_node .attrs ["body" ]
70+ vars_dict = self ._create_handlers_variables (body , vars_dict )
71+ elif node .op_type == 'If' :
72+ onnx_node = OnnxNode (node )
73+ then_branch = onnx_node .attrs ['then_branch' ]
74+ vars_dict = self ._create_handlers_variables (then_branch , vars_dict )
75+ else_branch = onnx_node .attrs ['else_branch' ]
76+ vars_dict = self ._create_handlers_variables (else_branch , vars_dict )
77+ return vars_dict
78+
4079 @tf .function
4180 def gen_tensor_dict (self , input_dict ):
4281 tensor_dict = dict (input_dict )
4382 tensor_dict .update (self .initializer_dict )
83+ tensor_dict .update (self .handler_variables )
4484
4585 for node in self .graph_def .node :
4686 onnx_node = OnnxNode (node )
@@ -58,6 +98,7 @@ def gen_tensor_dict(self, input_dict):
5898 def __call__ (self , ** kwargs ):
5999 tensor_dict = kwargs
60100 tensor_dict .update (self .initializer_dict )
101+ tensor_dict .update (self .handler_variables )
61102
62103 for node in self .graph_def .node :
63104 onnx_node = OnnxNode (node )
@@ -70,4 +111,41 @@ def __call__(self, **kwargs):
70111 tensor_dict .update (curr_node_output_map )
71112
72113 outputs = [tensor_dict [output ] for output in self .outputs ]
114+
115+ return outputs
116+
117+
118+ class TFModule (tf .Module ):
119+ """ TFModule is the tf.Module class used in backend.run_node.
120+ """
121+
122+ def __init__ (self , node , backend ):
123+ super (TFModule , self ).__init__ ()
124+ self .node = node
125+ self .backend = backend
126+ self .handlers = backend ._get_handlers (opset = None )
127+ self .handler_variables = self ._create_handlers_variables (dict ())
128+
129+ def _create_handlers_variables (self , vars_dict ):
130+ if self .handlers :
131+ handler = self .handlers [self .node .domain ].get (
132+ self .node .op_type ,
133+ None ) if self .node .domain in self .handlers else None
134+ if handler and bool (
135+ handler .get_req_vars_template (self .node , self .node .attrs )):
136+ for v_name , v_template in handler .get_req_vars_template (
137+ self .node , self .node .attrs ).items ():
138+ v_init , v_shape = v_template
139+ v_name = get_variable_name (self .node , v_name )
140+ vars_dict [v_name ] = tf .Variable (v_init ,
141+ dtype = v_init .dtype ,
142+ shape = v_shape ,
143+ name = v_name )
144+ return vars_dict
145+
146+ @tf .function
147+ def __call__ (self , ** input_dict ):
148+ input_dict .update (self .handler_variables )
149+ outputs = self .backend ._onnx_node_to_tensorflow_op (self .node , input_dict ,
150+ self .handlers )
73151 return outputs
0 commit comments