Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa
from .decompose_strided_slice_copy_pass import DecomposeStridedSliceCopyPass # noqa
from .decompose_sum_pass import DecomposeSumPass # noqa
from .decompose_tan_pass import DecomposeTanPass # noqa
from .decompose_tosa_unsupported_clamp_pass import ( # noqa
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
DecomposeSoftmaxPass,
DecomposeSoftmaxUnstablePass,
DecomposeSqrtPass,
DecomposeStridedSliceCopyPass,
DecomposeSumPass,
DecomposeTanPass,
DecomposeTOSAUnsupportedClampPass,
Expand Down Expand Up @@ -293,6 +294,7 @@ def _tosa_pipeline(
DecomposeUnfoldToGatherPass(),
DecomposeEmbeddingPass(),
DecomposeIndexSelectToGatherPass(),
DecomposeStridedSliceCopyPass(),
Conv1dUnsqueezePass(),
]
)
Expand Down
146 changes: 146 additions & 0 deletions backends/arm/_passes/decompose_strided_slice_copy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


def _get_strided_slice_copy_decomposition(op):
"""Return the operator overloads used by this decomposition."""
if op == exir_ops.edge.aten.slice_copy.Tensor:
return (
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.view_copy.default,
)
raise RuntimeError(f"Can't get strided slice_copy decomposition for op {op}")


def _fixup_start(start, dim_size):
"""Normalize start and clamp into [0, dim_size]."""
s = 0 if start is None else start
if s < 0:
s = s % dim_size
return max(0, min(s, dim_size))


def _fixup_end(end, dim_size):
"""Normalize end and clamp into [0, dim_size]."""
if end is None:
return dim_size
e = end
if e > dim_size:
e = dim_size
if e < 0:
e = e % dim_size
return max(0, min(e, dim_size))


class DecomposeStridedSliceCopyPass(ArmPass):
"""
Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops.

Given:
out = slice_copy(x, dim, start, end, step) with step > 1

Produce:
1) y = slice_copy(x, dim, start, end, 1) # span with unit step
2) pad y on the right to make length divisible by step (if needed)
3) y2 = view_copy(y, ..., U, step, ...) # split the sliced dim
4) y3 = slice_copy(y2, dim_i + 1, 0, 1, 1) # pick index 0 in each group
5) out = view_copy(y3, ...) # collapse the singleton dim

This implements "take every step-th element" using only unit-step slice + reshape.
"""

_passes_required_after: Set[Type[ExportPass]] = set()
_TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor}

def call_operator(self, op, args, kwargs, meta):
if op not in self._TARGET_OPS:
return super().call_operator(op, args, kwargs, meta)

# Only handle the non-unit-step case; leave unit-step to existing lowering.
if not (len(args) == 5 and args[4] != 1):
return super().call_operator(op, args, kwargs, meta)

x, dim, start, end, step = args
assert step > 0, "slice_copy step must be positive"

shape = x.data.shape
rank = len(shape)

# Normalize dim into [0, rank).
dim_i = dim % rank
dim_size = shape[dim_i]

# Normalize/clamp start/end into valid bounds.
start_i = _fixup_start(start, dim_size)
end_i = _fixup_end(end, dim_size)

L = end_i - start_i
if L <= 0:
# slice_copy would return empty; keep default behavior.
return super().call_operator(op, args, kwargs, meta)

slice_op, cat_op, view_op = _get_strided_slice_copy_decomposition(op)

# 1) Unit-step slice of the requested span:
# y = x[..., start_i:end_i, ...]
y = super().call_operator(
slice_op, (x, dim_i, start_i, end_i, 1), {}, meta, updated=True
)

# 2) Compute:
# U = ceil(L / step) (# of output elements along dim_i)
# pad_right = U*step - L (so that padded length becomes U*step)
U = (L + step - 1) // step
pad_right = U * step - L

# 3) If needed, right-pad along dim_i so that:
# after padding, y.shape[dim_i] == U*step
if pad_right > 0:
y_data = y.data
pad_shape = list(y_data.shape)
pad_shape[dim_i] = pad_right

# z: zeros with same dtype/device as y, shape matches y except
# z.shape[dim_i] = pad_right.
fill_value = False if y_data.dtype == torch.bool else 0
z = super().call_operator(
op=exir_ops.edge.aten.full.default,
args=(pad_shape, fill_value),
kwargs={"dtype": y_data.dtype, "device": y_data.device},
meta=meta,
updated=True,
)

# Concatenate on the right:
# y.shape[dim_i] : L -> L + pad_right == U*step
y = super().call_operator(cat_op, ([y, z], dim_i), {}, meta, updated=True)

# 4) Split the sliced dim: (U*step) -> (U, step)
y_t2 = y.data
split_shape = list(y_t2.shape)
split_shape[dim_i] = U
split_shape.insert(dim_i + 1, step)

y2 = super().call_operator(view_op, (y, split_shape), {}, meta, updated=True)

# 5) Take index 0 in the inserted "step" dimension:
# [..., U, step, ...] -> [..., U, 1, ...]
y3 = super().call_operator(
slice_op, (y2, dim_i + 1, 0, 1, 1), {}, meta, updated=True
)

# 6) Collapse y3's singleton step dim: [..., U, 1, ...] -> [..., U, ...].
out_shape = list(y_t2.shape) # y_t2: [..., U*step, ...]
out_shape[dim_i] = U # out_shape: [..., U, ...]

return super().call_operator(view_op, (y3, out_shape), {}, meta, updated=True)
72 changes: 59 additions & 13 deletions backends/arm/operator_support/slice_copy_support.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Declare operator support for ``aten.slice_copy`` in TOSA.
Support slicing with unit step only; emit a warning and reject otherwise.
Rely on preprocessing (e.g. DecomposeStridedSliceCopyPass) to rewrite any
non-unit-step slicing into supported ops. Assume static shapes and constant
slicing parameters.
"""
Check:
- args length is 4 or 5
- If present, require step > 0.
- Require dtype compatible with the selected TOSA profile (allow bool in both).
import logging
"""

import torch
import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
Expand All @@ -18,8 +24,6 @@
from executorch.backends.arm.tosa import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)


@register_tosa_support_check
class SliceCopySupported(SupportedTOSAOperatorCheck):
Expand All @@ -35,14 +39,56 @@ class SliceCopySupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]
"""Return True if the node is supported by TOSA.
if len(node.args) not in (4, 5):
self.reporter.report_reject(
node,
f"{node.target}: expected 4 or 5 args, got {len(node.args)}.",
)
return False

if len(node.args) == 5:
step = node.args[4]
if step <= 0: # type: ignore[operator]
self.reporter.report_reject(
node,
f"{node.target}: step must be > 0, got {step}.",
)
return False

Accept slice_copy when the step is 1 (or unspecified). Warn and reject
non-unit step sizes.
values_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr]

"""
args = node.args
if len(args) == 5 and (step := args[4]) != 1:
logger.warning(f"{node.target} with step size of {step} not supported.")
SUPPORTED_INT_DTYPES = (torch.int8, torch.int16, torch.int32)
SUPPORTED_FLOAT_DTYPES = (torch.float16, torch.float32)
SUPPORTED_DTYPES = (torch.bool,) + SUPPORTED_INT_DTYPES + SUPPORTED_FLOAT_DTYPES

# bool is supported in both INT and FP profiles
if values_dtype == torch.bool:
return True
# ints require INT profile
elif values_dtype in SUPPORTED_INT_DTYPES:
if not tosa_spec.support_integer():
self.reporter.report_reject(
node,
f"{node.target}: dtype {values_dtype} requires INT profile.",
)
return False

# fp16/fp32: either FP profile, or INT profile (via quantization)
elif values_dtype in SUPPORTED_FLOAT_DTYPES:
if not (tosa_spec.support_float() or tosa_spec.support_integer()):
self.reporter.report_reject(
node,
f"{node.target}: dtype {values_dtype} requires FP profile or "
"INT profile (with quantization).",
)
return False

else:
self.reporter.report_reject(
node,
f"{node.target}: unsupported values dtype {values_dtype}; "
f"expected one of {SUPPORTED_DTYPES}.",
)
return False

return True
7 changes: 3 additions & 4 deletions backends/arm/operators/op_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,12 @@ def define_node(
self.tosa_spec,
)

# See slice_copy_support.py
# TOSA.SLICE has no stride parameter. Any non-unit-step slice_copy must have been
# rewritten earlier (e.g. by DecomposeStridedSliceCopyPass), so only step=1 is legal here.
if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
raise ValueError("Unsupported combination of inputs")

# aten.slice_copy supports slicing in 1d at a time.
# The arguments are the actual input, dimension of slicing, start index, end index and optinal step or stride.
input_node, dim, start, end = inputs
input_node, dim, start, end = inputs[:4]

# Translate and check parameters in Pytorch dim order.
shape = input_node.shape
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import pytest

import torch
from executorch.backends.arm._passes import (
ConvertInt64ConstOpsToInt32Pass,
Expand Down Expand Up @@ -76,6 +78,9 @@ def prepare_model_and_inputs(self):
return text_encoder_model, text_encoder_model_inputs


@pytest.mark.xfail(
reason="MLETORCH-1601: Delegate output order mismatch from TOSA reference model."
)
def test_clip_text_with_projection_tosa_FP():
text_encoder_model, text_encoder_model_inputs = (
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
Expand Down
Loading
Loading