Skip to content

Comments

Fix static shape inference in pt.linalg.kron and add regression test#1898

Open
ayulockedin wants to merge 4 commits intopymc-devs:mainfrom
ayulockedin:fix-kron-static-shape
Open

Fix static shape inference in pt.linalg.kron and add regression test#1898
ayulockedin wants to merge 4 commits intopymc-devs:mainfrom
ayulockedin:fix-kron-static-shape

Conversation

@ayulockedin
Copy link

Description

This PR resolves the issue where pt.linalg.kron destroys static shape information, returning (None, None) even when input shapes are fully known.

The previous implementation relied on a vector-wise symbolic multiplication of shapes:
out_shape = tuple(a.shape * b.shape)

This forced the underlying Reshape Op to treat the entire shape vector as a single symbolic entity, which prevented the ShapeFeature from constant-folding individual dimensions into their static values.

The Fix

I implemented element-wise symbolic multiplication for the output shape:
[a.shape[i] * b.shape[i] for i in range(a.ndim)]

This provides the shape inference engine with enough granularity to resolve static constants (e.g., 4 * 3 = 12) at compile time while maintaining the symbolic integrity required for downstream operations like clone_replace.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ayulockedin
Copy link
Author

@jessegrabowski can you have a look at this when u have a moment thx

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Solution looks good. Just needs some cleanup.

ayulockedin and others added 2 commits February 20, 2026 14:06
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
np_val = np.kron(a, b)

# Regression test for issue #1867
assert out.shape == np_val.shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out is a numpy array. You need to test the shape of symbolic kron(x, y), not the numerical output of the compiled function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I was asserting against the compiled numerical output instead of the symbolic graph. I've updated the test to instantiate explicitly static tensors and assert against kron(x_static, y_static).type.shape to ensure the static shape information is properly preserved by the shape inference engine.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

pt.linalg.kron destroys static shape information

2 participants