Skip to content

Comments

Refactor advanced subtensor#1756

Open
jaanerik wants to merge 1 commit intopymc-devs:mainfrom
jaanerik:refactor-advanced-subtensor
Open

Refactor advanced subtensor#1756
jaanerik wants to merge 1 commit intopymc-devs:mainfrom
jaanerik:refactor-advanced-subtensor

Conversation

@jaanerik
Copy link

Description

Allows vectorizing AdvancedSetSubtensor.

Gemini picks up where Copilot left off.

Related Issue

Checklist

  • Checked that the pre-commit linting/style checks pass (ruff removed whitespace etc from copilot commits also I think)
  • Changed tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)
  • If you are a pro: each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 2 times, most recently from 3ce54c7 to 92f61ed Compare December 1, 2025 08:23
@jaanerik jaanerik marked this pull request as ready for review December 2, 2025 09:49
@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch from 92f61ed to a6cb68d Compare December 2, 2025 11:23
@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 3 times, most recently from 546100c to 4b02064 Compare December 9, 2025 12:21
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking pretty good.

Mostly I want to believe there is still room to simplify things / reuse code.
This is also a good opportunity to simplify the idx_list. There's no reason to use ScalarTypes in the dummy slices, and it's complicating our equality and hashing.

What about using simple integers to indicate what is the role of each index variable?

old_idx_list = (ps.int64, slice(ps.int64, None, None), ps.int64, slice(ps.int64, None, ps.int64))
new_idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))

Having the indices could probably come in handy anyway. With this we shouldn't need a custom hash / eq, we can just use the default one from __props__.

else:
x, y, *idxs = node.inputs
x, y = node.inputs[0], node.inputs[1]
tensor_inputs = node.inputs[2:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the name tensor_inputs, x, y are also tensor and inputs. Use index_variables?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This applies elsewhere

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using index_variables now.

Copy link
Member

@ricardoV94 ricardoV94 Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still using tensor_inputs in the last code you pushed

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry about that, renamed now.

# must already be raveled in the original graph, so we don't need to do anything to it
new_out = node.op(raveled_x, y, *new_idxs)
# But we must reshape the output to math the original shape
new_out = AdvancedIncSubtensor(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should use type(op) so that subclasses are respected. It may also make sense to add a method to these indexing Ops like op.with_new_indices() that clones itself with a new idx_list. Maybe that will be the one that handles creating the new idx_list, instead of having to be here in the rewrite.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently functions ravel_multidimensional_bool_idx and ravel_multidimensional_bool_idx don't assume the subclass in the same way, but it'd be nice if you could check. Also, if I am giving up on some rewrites too quickly here, please let me know.

Comment on lines 2602 to 2629
def __init__(self, idx_list):
"""
Initialize AdvancedSubtensor with index list.

Parameters
----------
idx_list : tuple
List of indices where slices are stored as-is,
and numerical indices are replaced by their types.
"""
self.idx_list = tuple(
index_vars_to_types(idx, allow_advanced=True) for idx in idx_list
)
# Store expected number of tensor inputs for validation
self.expected_inputs_len = len(
get_slice_elements(self.idx_list, lambda entry: isinstance(entry, Type))
)

def __hash__(self):
msg = []
for entry in self.idx_list:
if isinstance(entry, slice):
msg += [(entry.start, entry.stop, entry.step)]
else:
msg += [entry]

idx_list = tuple(msg)
return hash((type(self), idx_list))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already exists in Subtensor? If so create a BaseSubtensor class that handles idx_list and hash/equality based on it.

Make all Subtensor operations inherit from it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this advice may make no sense if we simplify the idx_list to not need custom hash / eq

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hash is a bit different based on whether it is *IncSubtensor or not in the current implementation. Wrote about the Python 3.11 slice not being hashable below.

)
else:
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
# With the new interface, all inputs are tensors, so Blockwise can handle them
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment should not mention a specific time period. Previous status is not relevant here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also we still want to avoid Blockwise eagerly if we can

Copy link
Author

@jaanerik jaanerik Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All time periods should be removed from comments that were added in this PR.

pattern.append("x")
new_args.append(slice(None))
else:
# Check for boolean index which consumes multiple dimensions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably right, but why does it matter that boolean indexing consumes multiple dimensions? Aren't we doing expand_dims where there was None -> replace new_axis by None slice -> index again?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a more dynamic approach, but I have had trouble with multidimensional bool arrays and ellipsis, e.g. x[multi_dim_bool_tensor,None,...] will not know where to add the new axis.

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 9, 2025

Just wanted to repeat, this is looking great. Thanks so far @jaanerik

I'm being picky because indexing is a pretty fundamental operation, so want to make sure we get it right this time.

@jaanerik
Copy link
Author

This is looking pretty good.

Mostly I want to believe there is still room to simplify things / reuse code. This is also a good opportunity to simplify the idx_list. There's no reason to use ScalarTypes in the dummy slices, and it's complicating our equality and hashing.

What about using simple integers to indicate what is the role of each index variable?

old_idx_list = (ps.int64, slice(ps.int64, None, Non), ps.int64, slice(3, ps.int64, 4))
new_idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))

Having the indices could probably come in handy anyway. With this we shouldn't need a custom hash / eq, we can just use the default one from __props__.

@ricardoV94 I am struggling a bit with understanding your example, because old_idx_list also has ints. Could you clarify how you see constant index and variable index both working here.

If you have the last slice of old_idx_list as slice(3, ps.int64, 4) and new one has slice(3, 4, 5) (this does not correspond exactly to your example, because I didn't understand it) then would 3 from new_idx_list indicate a Constant(3), 4 would indicate a Scalar() and 5 would indicate Constant(4)? Or would you discriminate between a static index and a variable index in a different way?

Absolutely no problem with being picky. I am very grateful for the feedback :)

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 14, 2025

I updated it, it was some copy-paste typos. Old_idx doesn't have ints, only ps.int64 to mark the place in the pattern that have input variables.

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 4 times, most recently from 4787bc9 to 5ba887b Compare December 30, 2025 00:41
@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 2 times, most recently from d821794 to 02a8d24 Compare January 6, 2026 13:37
@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 3 times, most recently from c84885b to e127e4b Compare February 3, 2026 14:52
@jaanerik
Copy link
Author

jaanerik commented Feb 4, 2026

Hmmm, so we should simply start counting from the first indexing, regardless of whether it's Subtensor or AdvancedSubtensor? That would allow us to compare the idx_list directly.

Start both at zero. x[idx] and x[idx].set(y) both have idx_list == [0]

Refactored to use ==. Could not respond to or resolve some comments directly, idk why.

Removed many comments, removed most makeslice, slicetype usage, but refactoring xtensor is a bigger change and is still quite verbose. Would like some feedback whether I should touch that code at all before continuing :) @ricardoV94

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch from 0b51466 to 13d025a Compare February 4, 2026 12:48
@ricardoV94
Copy link
Member

Let's try not to touch XTensor if we can, to keep this PR moving along. If it's using SliceTypes let it remain for those

Comment on lines +155 to 156
elif len(shape_parts) == 1:
shape = shape_parts[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a bug fix? If so can we add a regression test?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test_transform_take_scalar_index() in tests/tensor/rewriting/test_subtensor.py. A bug fix for the case where all shape_parts are empty tuples (e.g., scalar index on 1D array).

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking nearly done!

I saw you did some changes to xtensor. You mentioned that was growing out of scope so I didn't review. Let me know how you want to proceed there, or if you need guidance.

I flagged some changes that seem to be a conflict from main mis-resolved?

Otherwise it's really shaping up!!!

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch 2 times, most recently from f4b4b8c to 4ec9707 Compare February 10, 2026 15:17
@jaanerik
Copy link
Author

I currently simply reverted the xtensor refactor assuming it's okay to be labelled as out of scope for this PR. There was a very small refactor for xtensor that needed to stay in.

Thanks for your help. Asking you for another review : ) @ricardoV94

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty much done. Just some small nits / suggestions. I'll try to take care of them this week if you don't get to them before

"""
counter = [0]
self.idx_list = tuple(
index_vars_to_positions(entry, counter, allow_advanced=allow_advanced)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be done by the Op. Just make it receive the correct idx_list and maybe validate it. The allow_advanced at the base op also seems suspect

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to refactor it, but please take a look.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a full review at last.

Pushing back against one of the helpers that seems to be doing way too much (tbf it was already a mess before this PR), and a bunch of new tests / test changes.

Every single new test that was supposed to be a bug regression (that I tested) seems to pass just fine in main, which makes me suspect that it's regression against intermediate breaking changes from previous iterations. I would revert those.

@ricardoV94
Copy link
Member

@jaanerik, somewhat unrelated, would you like to join our internal dev discord? This contribution is pretty massive so you definitely earned a ticket

@jaanerik
Copy link
Author

@jaanerik, somewhat unrelated, would you like to join our internal dev discord? This contribution is pretty massive so you definitely earned a ticket

I joined a pymc discord. Is that the one you meant? Wrote to #chat

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch from 64c08a5 to c546c55 Compare February 18, 2026 15:38
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 suggestion, 1 nit, and 3 places I noted a wrong rebase from main.

I think this is the final round at last :)

@jaanerik jaanerik force-pushed the refactor-advanced-subtensor branch from 96b5945 to b0bdf68 Compare February 20, 2026 10:18
@jaanerik
Copy link
Author

Rebased (vectorize_node in xtensor was there before rebase as well, but due to a git conflict I added it back later). Also scan/rewriting was refactored because some numba and jax tests were failing.

@ricardoV94 ricardoV94 force-pushed the refactor-advanced-subtensor branch from b0bdf68 to 528d91e Compare February 20, 2026 17:38
@ricardoV94
Copy link
Member

@jaanerik I cleaned a few more things, squashed all commits and force pushed. I removed the new vectorize_node, we should do those in a separate PR, and I think they were masking a bug in the pre-existing rewrite which assumed vectorized AdvancedIncSubtensor with slices would never show up (because it was just not possible before)

@ricardoV94 ricardoV94 force-pushed the refactor-advanced-subtensor branch from 528d91e to 8c6052b Compare February 20, 2026 18:12
- newaxis is handled as explicit DimShuffel on the inputs
- slices are encoded internally, so the Ops only take numerical inputs

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
@ricardoV94 ricardoV94 force-pushed the refactor-advanced-subtensor branch from 8c6052b to 9016175 Compare February 20, 2026 23:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Reconsider use of SliceType and NoneType Variables as inputs to AdvancedIndexing

2 participants