[Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend#18956
[Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend#18956tlopex merged 1 commit intoapache:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the ONNX frontend to improve the handling of Reshape operations when applied to relax.ShapeExpr inputs. It ensures that identity flattens are preserved to maintain specialized shape handling, while other reshape targets are correctly converted to tensor operations. Corresponding tests were added and updated to verify these changes, including symbolic shape deduction. Feedback was provided to refactor the test node construction for better readability and conciseness.
| shape_node = helper.make_node("Shape", ["data"], ["shape_output"]) | ||
| gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"]) | ||
| nodes = [index_node, shape_node] | ||
| gather_input = "shape_output" | ||
|
|
||
| if with_reshape_flatten: | ||
| reshape_node = helper.make_node( | ||
| "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"] | ||
| ) | ||
| nodes.append(reshape_node) | ||
| gather_input = "reshaped_shape" | ||
|
|
||
| gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"]) | ||
| unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"]) | ||
| constant_of_shape_node = helper.make_node( | ||
| "ConstantOfShape", | ||
| ["unsqueeze_output"], | ||
| ["output"], | ||
| value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]), | ||
| ) | ||
| nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node]) |
There was a problem hiding this comment.
For improved readability and conciseness, the construction of the nodes list can be refactored. Creating the nodes directly within nodes.append and nodes.extend calls avoids unnecessary intermediate variables, making the graph's structure more direct and easier to understand.
| shape_node = helper.make_node("Shape", ["data"], ["shape_output"]) | |
| gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"]) | |
| nodes = [index_node, shape_node] | |
| gather_input = "shape_output" | |
| if with_reshape_flatten: | |
| reshape_node = helper.make_node( | |
| "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"] | |
| ) | |
| nodes.append(reshape_node) | |
| gather_input = "reshaped_shape" | |
| gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"]) | |
| unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"]) | |
| constant_of_shape_node = helper.make_node( | |
| "ConstantOfShape", | |
| ["unsqueeze_output"], | |
| ["output"], | |
| value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]), | |
| ) | |
| nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node]) | |
| shape_node = helper.make_node("Shape", ["data"], ["shape_output"]) | |
| nodes = [index_node, shape_node] | |
| gather_input = "shape_output" | |
| if with_reshape_flatten: | |
| nodes.append( | |
| helper.make_node("Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]) | |
| ) | |
| gather_input = "reshaped_shape" | |
| nodes.extend([ | |
| helper.make_node("Gather", [gather_input, "indices"], ["gather_output"]), | |
| helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"]), | |
| helper.make_node( | |
| "ConstantOfShape", | |
| ["unsqueeze_output"], | |
| ["output"], | |
| value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]), | |
| ), | |
| ]) |
Summary
Complete
Reshapehandling for shape values in the Relax ONNX frontend.Changes
ShapeExpr -> Reshape([-1])on the shape-specialized pathShapeExprto anint64tensor for other reshape targets and apply regular tensor reshape semanticsShape -> Reshape([-1])[-1]targets such as[1, 3]and[3, 1]Shape -> Reshape([-1]) -> Gather -> Unsqueezeshape-construction patternValidation
pytest -k 'test_symbolic_shape_deduction or test_reshape_shape_output or test_reshape'This PR completes the
Reshapelimitation in the Relax ONNX frontend operator work tracked in #18945.