
[Relax][ONNX] Complete ShapeExpr reshape handling in ONNX frontend #18956

Merged
tlopex merged 1 commit into apache:main from LudovicoYIN:relax-onnx-reshape-shapeexpr on Mar 30, 2026


Conversation

@LudovicoYIN (Contributor)

Summary

Complete Reshape handling for shape values in the Relax ONNX frontend.

Changes

  • keep ShapeExpr -> Reshape([-1]) on the shape-specialized path
  • materialize ShapeExpr to an int64 tensor for other reshape targets and apply regular tensor reshape semantics
  • add frontend coverage for Shape -> Reshape([-1])
  • add frontend coverage for reshaping shape outputs to non-[-1] targets such as [1, 3] and [3, 1]
  • extend symbolic shape deduction coverage to include the common Shape -> Reshape([-1]) -> Gather -> Unsqueeze shape-construction pattern
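The dispatch the first two bullets describe can be modeled in plain Python. This is a hedged sketch, not the actual TVM implementation: `reshape_shape_expr` and its `("shape_expr", ...)` / `("tensor", ...)` return encoding are hypothetical names used only to illustrate the two paths.

```python
import numpy as np

def reshape_shape_expr(shape_dims, target):
    """shape_dims: dims carried by a ShapeExpr; target: reshape target list."""
    if list(target) == [-1]:
        # Shape-specialized path: Reshape([-1]) on a 1-D shape value is an
        # identity flatten, so the symbolic shape is kept as a ShapeExpr.
        return ("shape_expr", tuple(shape_dims))
    # Any other target: materialize the shape as an int64 tensor and apply
    # regular tensor reshape semantics.
    return ("tensor", np.asarray(shape_dims, dtype="int64").reshape(target))
```

Under this model, a `[1, 3]` or `[3, 1]` target falls through to the tensor path, which is exactly the case the new non-`[-1]` tests exercise.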

Validation

  • pytest -k 'test_symbolic_shape_deduction or test_reshape_shape_output or test_reshape'

This PR resolves the Reshape limitation in the Relax ONNX frontend operator work tracked in #18945.


@gemini-code-assist (bot) left a comment


Code Review

This pull request updates the ONNX frontend to improve the handling of Reshape operations when applied to relax.ShapeExpr inputs. It ensures that identity flattens are preserved to maintain specialized shape handling, while other reshape targets are correctly converted to tensor operations. Corresponding tests were added and updated to verify these changes, including symbolic shape deduction. Feedback was provided to refactor the test node construction for better readability and conciseness.

Comment on lines 3673 to +3692
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"

if with_reshape_flatten:
    reshape_node = helper.make_node(
        "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
    )
    nodes.append(reshape_node)
    gather_input = "reshaped_shape"

gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
constant_of_shape_node = helper.make_node(
    "ConstantOfShape",
    ["unsqueeze_output"],
    ["output"],
    value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
)
nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])
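For orientation, the Shape -> Reshape([-1]) -> Gather -> Unsqueeze -> ConstantOfShape chain this test builds can be emulated with plain numpy. This is a sketch for illustration only, not an ONNX runtime; `emulate_graph` is a hypothetical name, and scalar `indices`/`axes` stand in for the corresponding initializer tensors.

```python
import numpy as np

def emulate_graph(data, indices, axes):
    # Shape: the input's shape as an int64 1-D tensor.
    shape_output = np.asarray(data.shape, dtype="int64")
    # Reshape([-1]): identity flatten of the (already 1-D) shape tensor.
    reshaped = shape_output.reshape(-1)
    # Gather: pick the selected dimension(s).
    gathered = np.take(reshaped, indices)
    # Unsqueeze: insert an axis so the result is a valid shape tensor.
    unsqueezed = np.expand_dims(gathered, axes)
    # ConstantOfShape with value 1.0: a float tensor of ones with that shape.
    return np.full(tuple(unsqueezed.reshape(-1)), 1.0, dtype="float32")
```

For a (2, 3, 4) input with `indices=1` and `axes=0`, the chain selects the middle dimension and yields a (3,)-shaped tensor of ones, which is the shape-construction pattern the symbolic deduction test covers.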


Severity: medium

For improved readability and conciseness, the construction of the nodes list can be refactored. Creating the nodes directly within nodes.append and nodes.extend calls avoids unnecessary intermediate variables, making the graph's structure more direct and easier to understand.

Suggested change
shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
gather_node = helper.make_node("Gather", ["shape_output", "indices"], ["gather_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"
if with_reshape_flatten:
    reshape_node = helper.make_node(
        "Reshape", ["shape_output", "target_shape"], ["reshaped_shape"]
    )
    nodes.append(reshape_node)
    gather_input = "reshaped_shape"
gather_node = helper.make_node("Gather", [gather_input, "indices"], ["gather_output"])
unsqueeze_node = helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"])
constant_of_shape_node = helper.make_node(
    "ConstantOfShape",
    ["unsqueeze_output"],
    ["output"],
    value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
)
nodes.extend([gather_node, unsqueeze_node, constant_of_shape_node])

shape_node = helper.make_node("Shape", ["data"], ["shape_output"])
nodes = [index_node, shape_node]
gather_input = "shape_output"
if with_reshape_flatten:
    nodes.append(
        helper.make_node("Reshape", ["shape_output", "target_shape"], ["reshaped_shape"])
    )
    gather_input = "reshaped_shape"
nodes.extend([
    helper.make_node("Gather", [gather_input, "indices"], ["gather_output"]),
    helper.make_node("Unsqueeze", ["gather_output", "axes"], ["unsqueeze_output"]),
    helper.make_node(
        "ConstantOfShape",
        ["unsqueeze_output"],
        ["output"],
        value=helper.make_tensor("value", TensorProto.FLOAT, [], [1]),
    ),
])


@tlopex (Member) left a comment


LGTM! Thank you!

@tlopex tlopex merged commit c79caf0 into apache:main Mar 30, 2026
10 checks passed
@LudovicoYIN LudovicoYIN deleted the relax-onnx-reshape-shapeexpr branch March 31, 2026 01:03
