Skip to content

Commit c84885b

Browse files
committed
Add xtensor hashing and restore is_full_slice for pymc-extras compatibility
1 parent 2cfa223 commit c84885b

3 files changed

Lines changed: 18 additions & 1 deletion

File tree

pytensor/tensor/rewriting/subtensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def transform_take(a, indices, axis):
162162
return transform_take(a, indices.flatten(), axis).reshape(shape, ndim=ndim)
163163

164164

165+
def is_full_slice(x):
166+
# Replace this function in pymc-extras and pymc with x==slice(None)
167+
return x == slice(None)
168+
169+
165170
def get_advsubtensor_axis(indices):
166171
"""Determine the axis at which an array index is applied.
167172

pytensor/tensor/subtensor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2870,7 +2870,6 @@ def get_slice_val(comp):
28702870
start_val = get_slice_val(entry.start)
28712871
stop_val = get_slice_val(entry.stop)
28722872
step_val = get_slice_val(entry.step)
2873-
28742873
full_indices.append(slice(start_val, stop_val, step_val))
28752874
else:
28762875
assert isinstance(entry, int)

pytensor/xtensor/indexing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,19 @@ def __init__(self, mode: Literal["set", "inc"], idx_list):
326326
self.mode = mode
327327
self.idx_list = idx_list
328328

329+
def __hash__(self):
330+
"""Hash using mode and idx_list. Slices are not hashable in Python < 3.12."""
331+
return hash((type(self), self.mode, self._hashable_idx_list()))
332+
333+
def _hashable_idx_list(self):
334+
"""Return a hashable version of idx_list (slices converted to tuples)."""
335+
return tuple(
336+
(slice, entry.start, entry.stop, entry.step)
337+
if isinstance(entry, slice)
338+
else entry
339+
for entry in self.idx_list
340+
)
341+
329342
def make_node(self, x, y, x_view, *index_inputs):
330343
try:
331344
y = as_xtensor(y)

0 commit comments

Comments
 (0)