From 3ab90f5294347190cc6bbe9ef5092d3a79ba4b56 Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Mon, 5 Jan 2026 13:45:12 +0000 Subject: [PATCH 1/3] Arm backend: Add support for aten.slice_copy with non-unit step - Decompose strided slice_copy into unit-step slice_copy plus optional right padding and view_copy reshapes - Update SliceCopySupported check for the supported pattern - Add non-unit-step slice tests Change-Id: Ida60ee2f42c283d50c9e3185dca1f9ea2238cf83 Signed-off-by: Yufeng Shi --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + .../decompose_strided_slice_copy_pass.py | 146 +++++++++++++++ .../operator_support/slice_copy_support.py | 72 ++++++-- backends/arm/operators/op_slice.py | 7 +- backends/arm/test/ops/test_slice.py | 172 ++++++++++++++++++ 6 files changed, 383 insertions(+), 17 deletions(-) create mode 100644 backends/arm/_passes/decompose_strided_slice_copy_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 286abee1155..997a4452630 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -79,6 +79,7 @@ from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa from .decompose_sqrt_pass import DecomposeSqrtPass # noqa +from .decompose_strided_slice_copy_pass import DecomposeStridedSliceCopyPass # noqa from .decompose_sum_pass import DecomposeSumPass # noqa from .decompose_tan_pass import DecomposeTanPass # noqa from .decompose_tosa_unsupported_clamp_pass import ( # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 94406c7b4e1..8237494638b 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -80,6 +80,7 @@ DecomposeSoftmaxPass, DecomposeSoftmaxUnstablePass, DecomposeSqrtPass, + DecomposeStridedSliceCopyPass, DecomposeSumPass, DecomposeTanPass, DecomposeTOSAUnsupportedClampPass, @@ -292,6 +293,7 @@ def _tosa_pipeline( DecomposeUnfoldToGatherPass(), DecomposeEmbeddingPass(), DecomposeIndexSelectToGatherPass(), + DecomposeStridedSliceCopyPass(), Conv1dUnsqueezePass(), ] ) diff --git a/backends/arm/_passes/decompose_strided_slice_copy_pass.py b/backends/arm/_passes/decompose_strided_slice_copy_pass.py new file mode 100644 index 00000000000..1ddf12e66e0 --- /dev/null +++ b/backends/arm/_passes/decompose_strided_slice_copy_pass.py @@ -0,0 +1,146 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
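+#
+# Illustrative example of the decomposition implemented in this pass (the
+# concrete numbers are for exposition only): slice_copy(x, dim=0, start=1,
+# end=10, step=4) on a 1-D tensor of length 10 selects x[1], x[5], x[9].
+# With L = end - start = 9, U = ceil(L / step) = 3 and pad_right = 3:
+#
+#   y   = slice_copy(x, 0, 1, 10, 1)    # [x1 .. x9], length 9
+#   y   = cat([y, zeros(3)], 0)         # length 12 == U * step
+#   y2  = view_copy(y, [3, 4])          # split into (U, step)
+#   y3  = slice_copy(y2, 1, 0, 1, 1)    # column 0 -> x1, x5, x9
+#   out = view_copy(y3, [3])            # collapse the singleton dim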
+ +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +def _get_strided_slice_copy_decomposition(op): + """Return the operator overloads used by this decomposition.""" + if op == exir_ops.edge.aten.slice_copy.Tensor: + return ( + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.view_copy.default, + ) + raise RuntimeError(f"Can't get strided slice_copy decomposition for op {op}") + + +def _fixup_start(start, dim_size): + """Normalize start and clamp into [0, dim_size].""" + s = 0 if start is None else start + if s < 0: + s = s % dim_size + return max(0, min(s, dim_size)) + + +def _fixup_end(end, dim_size): + """Normalize end and clamp into [0, dim_size].""" + if end is None: + return dim_size + e = end + if e > dim_size: + e = dim_size + if e < 0: + e = e % dim_size + return max(0, min(e, dim_size)) + + +class DecomposeStridedSliceCopyPass(ArmPass): + """ + Decompose edge.aten.slice_copy.Tensor with non-unit step into supported ops. + + Given: + out = slice_copy(x, dim, start, end, step) with step > 1 + + Produce: + 1) y = slice_copy(x, dim, start, end, 1) # span with unit step + 2) pad y on the right to make length divisible by step (if needed) + 3) y2 = view_copy(y, ..., U, step, ...) # split the sliced dim + 4) y3 = slice_copy(y2, dim_i + 1, 0, 1, 1) # pick index 0 in each group + 5) out = view_copy(y3, ...) # collapse the singleton dim + + This implements "take every step-th element" using only unit-step slice + reshape. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + _TARGET_OPS = {exir_ops.edge.aten.slice_copy.Tensor} + + def call_operator(self, op, args, kwargs, meta): + if op not in self._TARGET_OPS: + return super().call_operator(op, args, kwargs, meta) + + # Only handle the non-unit-step case; leave unit-step to existing lowering. + if not (len(args) == 5 and args[4] != 1): + return super().call_operator(op, args, kwargs, meta) + + x, dim, start, end, step = args + assert step > 0, "slice_copy step must be positive" + + shape = x.data.shape + rank = len(shape) + + # Normalize dim into [0, rank). + dim_i = dim % rank + dim_size = shape[dim_i] + + # Normalize/clamp start/end into valid bounds. + start_i = _fixup_start(start, dim_size) + end_i = _fixup_end(end, dim_size) + + L = end_i - start_i + if L <= 0: + # slice_copy would return empty; keep default behavior. + return super().call_operator(op, args, kwargs, meta) + + slice_op, cat_op, view_op = _get_strided_slice_copy_decomposition(op) + + # 1) Unit-step slice of the requested span: + # y = x[..., start_i:end_i, ...] + y = super().call_operator( + slice_op, (x, dim_i, start_i, end_i, 1), {}, meta, updated=True + ) + + # 2) Compute: + # U = ceil(L / step) (# of output elements along dim_i) + # pad_right = U*step - L (so that padded length becomes U*step) + U = (L + step - 1) // step + pad_right = U * step - L + + # 3) If needed, right-pad along dim_i so that: + # after padding, y.shape[dim_i] == U*step + if pad_right > 0: + y_data = y.data + pad_shape = list(y_data.shape) + pad_shape[dim_i] = pad_right + + # z: zeros with same dtype/device as y, shape matches y except + # z.shape[dim_i] = pad_right. 
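+            # Note that the padded values are never selected: index 0 of the
+            # last (U, step) group sits at (U - 1) * step < L, i.e. inside
+            # the original span, so the fill value cannot affect the result.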
+ fill_value = False if y_data.dtype == torch.bool else 0 + z = super().call_operator( + op=exir_ops.edge.aten.full.default, + args=(pad_shape, fill_value), + kwargs={"dtype": y_data.dtype, "device": y_data.device}, + meta=meta, + updated=True, + ) + + # Concatenate on the right: + # y.shape[dim_i] : L -> L + pad_right == U*step + y = super().call_operator(cat_op, ([y, z], dim_i), {}, meta, updated=True) + + # 4) Split the sliced dim: (U*step) -> (U, step) + y_t2 = y.data + split_shape = list(y_t2.shape) + split_shape[dim_i] = U + split_shape.insert(dim_i + 1, step) + + y2 = super().call_operator(view_op, (y, split_shape), {}, meta, updated=True) + + # 5) Take index 0 in the inserted "step" dimension: + # [..., U, step, ...] -> [..., U, 1, ...] + y3 = super().call_operator( + slice_op, (y2, dim_i + 1, 0, 1, 1), {}, meta, updated=True + ) + + # 6) Collapse y3's singleton step dim: [..., U, 1, ...] -> [..., U, ...]. + out_shape = list(y_t2.shape) # y_t2: [..., U*step, ...] + out_shape[dim_i] = U # out_shape: [..., U, ...] + + return super().call_operator(view_op, (y3, out_shape), {}, meta, updated=True) diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py index 77f3e97eb39..e3606711d85 100644 --- a/backends/arm/operator_support/slice_copy_support.py +++ b/backends/arm/operator_support/slice_copy_support.py @@ -1,15 +1,21 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Declare operator support for ``aten.slice_copy`` in TOSA. -Support slicing with unit step only; emit a warning and reject otherwise. +Rely on preprocessing (e.g. DecomposeStridedSliceCopyPass) to rewrite any +non-unit-step slicing into supported ops. Assume static shapes and constant +slicing parameters. -""" +Check: +- args length is 4 or 5 +- If present, require step > 0. +- Require dtype compatible with the selected TOSA profile (allow bool in both). -import logging +""" +import torch import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( register_tosa_support_check, @@ -18,8 +24,6 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops -logger = logging.getLogger(__name__) - @register_tosa_support_check class SliceCopySupported(SupportedTOSAOperatorCheck): @@ -35,14 +39,56 @@ class SliceCopySupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] - """Return True if the node is supported by TOSA. + if len(node.args) not in (4, 5): + self.reporter.report_reject( + node, + f"{node.target}: expected 4 or 5 args, got {len(node.args)}.", + ) + return False + + if len(node.args) == 5: + step = node.args[4] + if step <= 0: # type: ignore[operator] + self.reporter.report_reject( + node, + f"{node.target}: step must be > 0, got {step}.", + ) + return False - Accept slice_copy when the step is 1 (or unspecified). Warn and reject - non-unit step sizes. 
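+        # Read the input dtype from the FakeTensor recorded in the node's
+        # "val" meta and gate it on the active TOSA profile below.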
+ values_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr] - """ - args = node.args - if len(args) == 5 and (step := args[4]) != 1: - logger.warning(f"{node.target} with step size of {step} not supported.") + SUPPORTED_INT_DTYPES = (torch.int8, torch.int16, torch.int32) + SUPPORTED_FLOAT_DTYPES = (torch.float16, torch.float32) + SUPPORTED_DTYPES = (torch.bool,) + SUPPORTED_INT_DTYPES + SUPPORTED_FLOAT_DTYPES + + # bool is supported in both INT and FP profiles + if values_dtype == torch.bool: + return True + # ints require INT profile + elif values_dtype in SUPPORTED_INT_DTYPES: + if not tosa_spec.support_integer(): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires INT profile.", + ) + return False + + # fp16/fp32: either FP profile, or INT profile (via quantization) + elif values_dtype in SUPPORTED_FLOAT_DTYPES: + if not (tosa_spec.support_float() or tosa_spec.support_integer()): + self.reporter.report_reject( + node, + f"{node.target}: dtype {values_dtype} requires FP profile or " + "INT profile (with quantization).", + ) + return False + + else: + self.reporter.report_reject( + node, + f"{node.target}: unsupported values dtype {values_dtype}; " + f"expected one of {SUPPORTED_DTYPES}.", + ) return False + return True diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 9dac73d84ba..004c2e353f8 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -83,13 +83,12 @@ def define_node( self.tosa_spec, ) - # See slice_copy_support.py + # TOSA.SLICE has no stride parameter. Any non-unit-step slice_copy must have been + # rewritten earlier (e.g. by DecomposeStridedSliceCopyPass), so only step=1 is legal here. if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)): raise ValueError("Unsupported combination of inputs") - # aten.slice_copy supports slicing in 1d at a time. - # The arguments are the actual input, dimension of slicing, start index, end index and optinal step or stride. - input_node, dim, start, end = inputs + input_node, dim, start, end = inputs[:4] # Translate and check parameters in Pytorch dim order. 
shape = input_node.shape diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 00498e6e1be..0adb3d1b113 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -180,3 +180,175 @@ def test_slice_tensor_16a8w_u85_INT(test_data: torch.Tensor): get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) ) pipeline.run() + + +#################################### +## Non-unit step / stride slicing ## +#################################### + +input_t_step = Tuple[torch.Tensor, int, int, int, int] # (x, dim, start, end, step) + + +test_data_step_fp = { + # x[0:10:2] == x[::2] + "arange_fp32_1d_step2": lambda: ( + torch.arange(10, dtype=torch.float32), + 0, + 0, + 10, + 2, + ), + # x[:, 1:10:4] + "arange_fp32_2d_step4": lambda: ( + torch.arange(40, dtype=torch.float32).reshape(4, 10), + 1, + 1, + 10, + 4, + ), + # x[:, 0:4:2, :] + "arange_fp32_3d_dim1_step2": lambda: ( + torch.arange(2 * 4 * 17, dtype=torch.float32).reshape(2, 4, 17), + 1, + 0, + 4, + 2, + ), + # x[:, :, :, 0:17:4] + "arange_fp32_4d_dim3_step4": lambda: ( + torch.arange(2 * 3 * 5 * 17, dtype=torch.float32).reshape(2, 3, 5, 17), + 3, + 0, + 17, + 4, + ), + # x[:, 0:12:4] + "bool_2d_step4": lambda: ( + (torch.rand((2, 12)) < 0.5), # [2,12], dtype=bool + 1, + 0, + 12, + 4, + ), +} + +test_data_step_int = { + # x[:, 0:9:3] + "rand_int8_2d_step3": lambda: ( + torch.randint(-8, 8, size=(3, 9), dtype=torch.int8), + 1, + 0, + 9, + 3, + ), + # x[:, 0:6:2, :] + "arange_int32_3d_step2_dim1": lambda: ( + torch.arange(2 * 6 * 4, dtype=torch.int32).reshape(2, 6, 4), + 1, + 0, + 6, + 2, + ), + # x[:, :, :, 0:19:4] + "arange_int8_4d_dim3_step4": lambda: ( + torch.arange(2 * 2 * 4 * 19, dtype=torch.int8).reshape(2, 2, 4, 19), + 3, + 0, + 19, + 4, + ), + # x[:, 0:12:4] + "bool_2d_step4": lambda: ( + (torch.rand((2, 12)) < 0.5), # [2,12], dtype=bool + 1, + 0, + 12, + 4, + ), +} + + +class SliceWithStep(torch.nn.Module): + def forward( + self, x: torch.Tensor, dim_: int, start_: int, end_: int, step_: int + ) -> torch.Tensor: + # Use aten.slice to generate a slice_copy in Edge for lowering. 
+ return torch.ops.aten.slice.Tensor(x, dim_, start_, end_, step_) + + +@common.parametrize("test_data", test_data_step_fp) +def test_slice_tensor_step_tosa_FP(test_data: Tuple): + pipeline = TosaPipelineFP[input_t_step]( + SliceWithStep(), + test_data(), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_step_int | test_data_step_fp) +def test_slice_tensor_step_tosa_INT(test_data: Tuple): + pipeline = TosaPipelineINT[input_t_step]( + SliceWithStep(), + test_data(), + aten_op=aten_op, + exir_op=exir_op, + ) + pipeline.run() + + +@common.parametrize( + "test_data", + test_data_step_int | test_data_step_fp, + xfails={ + "bool_2d_step4": "MLETORCH-1744: bool test fails", + }, +) +@common.XfailIfNoCorstone300 +def test_slice_tensor_step_u55_INT(test_data: Tuple): + pipeline = EthosU55PipelineINT[input_t1]( + SliceWithStep(), + test_data(), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_step_int | test_data_step_fp) +@common.XfailIfNoCorstone320 +def test_slice_tensor_step_u85_INT(test_data: Tuple): + pipeline = EthosU85PipelineINT[input_t1]( + SliceWithStep(), + test_data(), + aten_ops=aten_op, + exir_ops=exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_step_int | test_data_step_fp) +@common.SkipIfNoModelConverter +def test_slice_tensor_step_vgf_no_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t_step]( + SliceWithStep(), + test_data(), + aten_op=aten_op, + exir_op=exir_op, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_step_int | test_data_step_fp) +@common.SkipIfNoModelConverter +def test_slice_tensor_step_vgf_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t_step]( + SliceWithStep(), + test_data(), + aten_op=aten_op, + exir_op=exir_op, + quantize=True, + ) + pipeline.run() From c076d9aa584b064d8a6623564187f45309cbcb1a Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Mon, 26 Jan 2026 16:02:38 +0000 Subject: [PATCH 2/3] Arm backend: Fix naming convention Change-Id: I8479a4daed1e8be4bda4591e04dfbcadc297ab4f Signed-off-by: Yufeng Shi --- backends/arm/test/ops/test_slice.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/arm/test/ops/test_slice.py b/backends/arm/test/ops/test_slice.py index 0adb3d1b113..d9108f4121a 100644 --- a/backends/arm/test/ops/test_slice.py +++ b/backends/arm/test/ops/test_slice.py @@ -277,7 +277,7 @@ def forward( @common.parametrize("test_data", test_data_step_fp) -def test_slice_tensor_step_tosa_FP(test_data: Tuple): +def test_slice_tensor_tosa_FP_step(test_data: Tuple): pipeline = TosaPipelineFP[input_t_step]( SliceWithStep(), test_data(), @@ -288,7 +288,7 @@ def test_slice_tensor_step_tosa_FP(test_data: Tuple): @common.parametrize("test_data", test_data_step_int | test_data_step_fp) -def test_slice_tensor_step_tosa_INT(test_data: Tuple): +def test_slice_tensor_tosa_INT_step(test_data: Tuple): pipeline = TosaPipelineINT[input_t_step]( SliceWithStep(), test_data(), @@ -306,7 +306,7 @@ def test_slice_tensor_step_tosa_INT(test_data: Tuple): }, ) @common.XfailIfNoCorstone300 -def test_slice_tensor_step_u55_INT(test_data: Tuple): +def test_slice_tensor_u55_INT_step(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( SliceWithStep(), test_data(), @@ -318,7 +318,7 @@ def test_slice_tensor_step_u55_INT(test_data: Tuple): @common.parametrize("test_data", test_data_step_int | test_data_step_fp) @common.XfailIfNoCorstone320 -def 
test_slice_tensor_step_u85_INT(test_data: Tuple): +def test_slice_tensor_u85_INT_step(test_data: Tuple): pipeline = EthosU85PipelineINT[input_t1]( SliceWithStep(), test_data(), @@ -330,7 +330,7 @@ def test_slice_tensor_step_u85_INT(test_data: Tuple): @common.parametrize("test_data", test_data_step_int | test_data_step_fp) @common.SkipIfNoModelConverter -def test_slice_tensor_step_vgf_no_quant(test_data: Tuple): +def test_slice_tensor_vgf_no_quant_step(test_data: Tuple): pipeline = VgfPipeline[input_t_step]( SliceWithStep(), test_data(), @@ -343,7 +343,7 @@ def test_slice_tensor_step_vgf_no_quant(test_data: Tuple): @common.parametrize("test_data", test_data_step_int | test_data_step_fp) @common.SkipIfNoModelConverter -def test_slice_tensor_step_vgf_quant(test_data: Tuple): +def test_slice_tensor_vgf_quant_step(test_data: Tuple): pipeline = VgfPipeline[input_t_step]( SliceWithStep(), test_data(), From 9af4e5463c78cc4bc7339572e2a58102e0b518db Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Tue, 27 Jan 2026 11:00:59 +0000 Subject: [PATCH 3/3] Arm backend: Add xfail for the failed test Change-Id: I06319f5378a3fed60f5f20499671389dbd0e68e8 Signed-off-by: Yufeng Shi --- .../stable_diffusion/test_CLIPTextModelWithProjection.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py index b68ad3979f1..f11310eb24c 100644 --- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py +++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -6,6 +6,8 @@ from typing import Tuple +import pytest + import torch from executorch.backends.arm._passes import ( ConvertInt64ConstOpsToInt32Pass, @@ -76,6 +78,9 @@ def prepare_model_and_inputs(self): return text_encoder_model, text_encoder_model_inputs +@pytest.mark.xfail( + reason="MLETORCH-1601: Delegate output order mismatch from TOSA reference model." +) def test_clip_text_with_projection_tosa_FP(): text_encoder_model, text_encoder_model_inputs = ( TestCLIPTextModelWithProjection().prepare_model_and_inputs()