diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index 3ef5fc02adb..19665f37083 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -8,6 +8,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa from .decompose_hardswish_pass import DecomposeHardswishPass # noqa from .decompose_mean_pass import DecomposeMeanPass # noqa +from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/activation_fusion_pass.py b/backends/cortex_m/passes/activation_fusion_pass.py index a53c065aaa4..ff61f3493dd 100644 --- a/backends/cortex_m/passes/activation_fusion_pass.py +++ b/backends/cortex_m/passes/activation_fusion_pass.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +8,10 @@ import executorch.backends.cortex_m.ops.operators # noqa: F401 from executorch.backends.arm._passes.quant_args import QuantArgs -from executorch.backends.cortex_m.passes.passes_utils import quantize_val +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass): """Fuse activations into preceding Cortex-M quantized operators. Supported activation patterns: - q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq + q-> [conv2d, linear, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq Fusing works by clamping the quantized output range (and zero-point when required) of the preceding Cortex-M operator, then removing the activation @@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass): exir_ops.edge.aten.clamp.default, } + MAX_POOL_OPS = { + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, + } + FUSE_OPS = { exir_ops.edge.aten.linear.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.max_pool2d.default, + exir_ops.edge.aten.max_pool2d_with_indices.default, } def _get_validated_qparams(self, node, input_node): @@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node): ) return None - match node.target: - case exir_ops.edge.aten.relu.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = qmax - case exir_ops.edge.aten.hardtanh.default: - quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax) - quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax) - case exir_ops.edge.aten.hardsigmoid.default: - quantized_min_val = quantize_val(0, scale, zp, qmin, qmax) - quantized_max_val = quantize_val(1, scale, zp, qmin, qmax) - case exir_ops.edge.aten.clamp.default: - quantized_min_val = ( - quantize_val(node.args[1], scale, zp, qmin, qmax) - if node.args[1] is not None - else qmin - ) - # Last arg is removed if none, so check length of args here - quantized_max_val = ( - quantize_val(node.args[2], scale, zp, qmin, qmax) - if len(node.args) == 3 - else qmax + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot fuse activation %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min_val = ( + quantize_val(min_val, scale, zp, qmin, qmax) + if min_val is not None + else qmin + ) + quantized_max_val = ( + quantize_val(max_val, scale, zp, qmin, qmax) + if max_val is not None + else qmax + ) + + if input_node.target in self.MAX_POOL_OPS: + if node.target == exir_ops.edge.aten.hardsigmoid.default: + logger.warning( + "Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams.", + node.name, ) - case _: - raise RuntimeError(f"Unexpected target {node.target}.") + return None + # Max-pool keeps scale and zero-point unchanged and lowers fused + # activation bounds separately, so only qmin/qmax need updating here. + qparams_dict["qmin"] = int(quantized_min_val) + qparams_dict["qmax"] = int(quantized_max_val) + return qparams_dict # If the minimal quantized value is larger than the qmin, it means that the quantized range contains # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters. diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index 9fef167ef09..074eb6118d0 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -28,6 +28,7 @@ from .convert_to_cortex_m_pass import ConvertToCortexMPass from .decompose_hardswish_pass import DecomposeHardswishPass from .decompose_mean_pass import DecomposeMeanPass +from .quantized_clamp_activation_pass import QuantizedClampActivationPass from .quantized_op_fusion_pass import QuantizedOpFusionPass from .replace_quant_nodes_pass import ReplaceQuantNodesPass @@ -42,6 +43,7 @@ class CortexMPassManager(PassManager): ReplaceScalarWithTensorArgPass, ReplaceQuantNodesPass, ActivationFusionPass, + QuantizedClampActivationPass, DecomposeHardswishPass, QuantizedOpFusionPass, ConvertToCortexMPass, diff --git a/backends/cortex_m/passes/passes_utils.py b/backends/cortex_m/passes/passes_utils.py index a6f68022430..fcbfa301b06 100644 --- a/backends/cortex_m/passes/passes_utils.py +++ b/backends/cortex_m/passes/passes_utils.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Any import torch @@ -21,6 +22,56 @@ def quantize_val(val, scale, zp, qmin, qmax): return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax)) +def extract_constant_scalar(arg: Any) -> float | None: + if arg is None: + return None + if isinstance(arg, (int, float)): + return float(arg) + if isinstance(arg, Node): + if arg.op == "call_function" and arg.target in { + exir_ops.edge.aten.full_like.default, + exir_ops.edge.aten.full.default, + torch.ops.aten.full_like.default, + torch.ops.aten.full.default, + }: + fill_arg = arg.args[1] if len(arg.args) > 1 else None + return extract_constant_scalar(fill_arg) + val = arg.meta.get("val") + if val is None: + return None + return extract_constant_scalar(val) + return None + + +def get_activation_bounds(node: Node) -> tuple[float | None, float | None] | None: + bounds: tuple[float | None, float | None] + match node.target: + case exir_ops.edge.aten.relu.default | exir_ops.edge.aten.relu_.default: + bounds = (0.0, None) + case exir_ops.edge.aten.hardsigmoid.default: + bounds = (0.0, 1.0) + case exir_ops.edge.aten.hardtanh.default | exir_ops.edge.aten.hardtanh_.default: + bounds = ( + extract_constant_scalar(node.args[1]), + extract_constant_scalar(node.args[2]), + ) + case exir_ops.edge.aten.clamp.default | exir_ops.edge.aten.clamp.Tensor: + bounds = ( + extract_constant_scalar(node.args[1]) if len(node.args) > 1 else None, + extract_constant_scalar(node.args[2]) if len(node.args) > 2 else None, + ) + case _: + return None + + min_val, max_val = bounds + if len(node.args) > 1 and min_val is None and node.args[1] is not None: + return None + if len(node.args) > 2 and max_val is None and node.args[2] is not None: + return None + + return bounds + + def dequantize_per_tensor_cmsis( qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int ) -> torch.Tensor: diff --git a/backends/cortex_m/passes/quantized_clamp_activation_pass.py b/backends/cortex_m/passes/quantized_clamp_activation_pass.py new file mode 100644 index 00000000000..2ba003dbc01 --- /dev/null +++ b/backends/cortex_m/passes/quantized_clamp_activation_pass.py @@ -0,0 +1,129 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Any + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) +from executorch.backends.cortex_m.passes.passes_utils import ( + get_activation_bounds, + quantize_val, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_manager import PassResult + +logger = logging.getLogger(__name__) + + +class QuantizedClampActivationPass(ExportPass): + """Canonicalize remaining clamp-like activations on quantized tensors. + + This pass runs after activation fusion, so any remaining relu/hardtanh/clamp + still needs to execute in the quantized domain. It rewrites relu and + hardtanh variants to `aten.clamp.default` and quantizes the clamp bounds so + the portable kernel consumes and produces int8 tensors. + """ + + TARGETS = { + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.relu_.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.hardtanh_.default, + exir_ops.edge.aten.clamp.default, + exir_ops.edge.aten.clamp.Tensor, + } + + def _get_quantized_bounds( + self, node: Node, qparams_dict: dict[str, Any] + ) -> tuple[int | None, int | None] | None: + qmin = qparams_dict["qmin"] + qmax = qparams_dict["qmax"] + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + + bounds = get_activation_bounds(node) + if bounds is None: + logger.warning( + "Cannot rewrite %s because bounds are not compile-time scalars.", + node.name, + ) + return None + min_val, max_val = bounds + + quantized_min = ( + int(quantize_val(min_val, scale, zp, qmin, qmax)) + if min_val is not None + else None + ) + quantized_max = ( + int(quantize_val(max_val, scale, zp, qmin, qmax)) + if max_val is not None + else None + ) + return quantized_min, quantized_max + + def _is_quantized_int8_activation(self, node: Node) -> bool: + input_node = node.args[0] if len(node.args) > 0 else None + if not isinstance(input_node, Node): + return False + try: + tensor = get_first_fake_tensor(input_node) + except Exception: + return False + if tensor is None or tensor.dtype != torch.int8: + return False + + try: + qparams_dict = get_output_qparams(node)[0]._asdict() + except (ValueError, KeyError): + logger.warning( + "Cannot quantize clamp bounds for %s without output qparams.", + node.name, + ) + return False + + scale = qparams_dict["scale"] + zp = qparams_dict["zp"] + if not isinstance(scale, float) or not isinstance(zp, int): + logger.warning( + "Cannot quantize clamp bounds for %s with non per-tensor qparams.", + node.name, + ) + return False + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + + for node in list(graph_module.graph.nodes): + if node.op != "call_function" or node.target not in self.TARGETS: + continue + if not self._is_quantized_int8_activation(node): + continue + + qparams_dict = get_output_qparams(node)[0]._asdict() + + quantized_bounds = self._get_quantized_bounds(node, qparams_dict) + if quantized_bounds is None: + continue + + quantized_min, quantized_max = quantized_bounds + node.target = exir_ops.edge.aten.clamp.default + node.args = (node.args[0], quantized_min, quantized_max) + modified = True + + if modified: + graph_module = super().call(graph_module).graph_module + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index a2fc7d19b21..55d93be5183 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from typing import Any, Callable import torch @@ -90,10 +91,45 @@ torch.ops.aten.max_pool2d_with_indices.default, } +POOL_FUSED_ACTIVATION_TARGETS = { + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_.default, +} + class CortexMQuantizationConfig(QuantizationConfig): """Configures quantization, while enforcing cortex-m specific constraints.""" + @staticmethod + def _get_shared_pool_input(node: Node | None) -> Node | None: + if node is None or len(node.args) == 0: + return None + + input_node = node.args[0] + if not isinstance(input_node, Node): + return None + + if input_node.target in POOL_SHARE_OUTPUT_TARGETS: + if len(input_node.args) > 0 and isinstance(input_node.args[0], Node): + return input_node.args[0] + return None + + if input_node.target == operator.getitem and len(input_node.args) > 0: + pool_node = input_node.args[0] + if ( + isinstance(pool_node, Node) + and pool_node.target in POOL_SHARE_OUTPUT_TARGETS + and len(pool_node.args) > 0 + and isinstance(pool_node.args[0], Node) + ): + return pool_node.args[0] + + return None + def get_input_act_qspec( self, node: Node | None = None, input_node: Node | None = None ) -> QuantizationSpecBase | None: @@ -121,6 +157,10 @@ def get_output_act_qspec( if isinstance(input_node, Node): return SharedQuantizationSpec((input_node, node)) return super().get_output_act_qspec() + if node is not None and node.target in POOL_FUSED_ACTIVATION_TARGETS: + shared_pool_input = self._get_shared_pool_input(node) + if shared_pool_input is not None: + return SharedQuantizationSpec(shared_pool_input) return super().get_output_act_qspec() def get_weight_qspec(self, node: Node | None = None) -> QuantizationSpecBase | None: diff --git a/backends/cortex_m/quantizer/quantizer_support.py b/backends/cortex_m/quantizer/quantizer_support.py index 2cf0483f74b..3dfbb67638a 100644 --- a/backends/cortex_m/quantizer/quantizer_support.py +++ b/backends/cortex_m/quantizer/quantizer_support.py @@ -122,7 +122,31 @@ POOL_OP_PATTERNS = { (torch.ops.aten.avg_pool2d.default,): CortexMAvgPool2DCheck, (torch.ops.aten.max_pool2d.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, (torch.ops.aten.max_pool2d_with_indices.default,): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.relu.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.hardtanh.default, + ): CortexMMaxPool2DCheck, + ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.clamp.default, + ): CortexMMaxPool2DCheck, } BMM_OP_PATTERNS = { diff --git a/backends/cortex_m/test/build_test_runner.sh b/backends/cortex_m/test/build_test_runner.sh index 6ac9aa55e73..2505f83c9da 100755 --- a/backends/cortex_m/test/build_test_runner.sh +++ b/backends/cortex_m/test/build_test_runner.sh @@ -21,6 +21,7 @@ build_root_test_dir="${et_root_dir}/arm_test/arm_semihosting_executor_runner_cor select_ops_list="\ aten::add.out,\ +aten::clamp.out,\ aten::mul.out,\ aten::convolution.out,\ dim_order_ops::_clone_dim_order.out,\ diff --git a/backends/cortex_m/test/misc/test_portable_int8.py b/backends/cortex_m/test/misc/test_portable_int8.py index 82b719230eb..4e3b5f41561 100644 --- a/backends/cortex_m/test/misc/test_portable_int8.py +++ b/backends/cortex_m/test/misc/test_portable_int8.py @@ -662,12 +662,6 @@ def _quantize_and_export( xfails: dict[str, xfail_type] = { "contiguous": "MLETORCH-1863: Contiguos no-op is removed in to-edge, leading to unnecessary Q-DQ-Q-DQ chain.", - "clamp": "MLETORCH-1864: Support non-fused clamp-type activations.", - "clamp_tensor": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh": "MLETORCH-1864: Support non-fused clamp-type activations.", - "hardtanh_": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu": "MLETORCH-1864: Support non-fused clamp-type activations.", - "relu_": "MLETORCH-1864: Support non-fused clamp-type activations.", "eq_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ne_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", "ge_scalar": "MLETORCH-1865: Properly support flaky scalar comparison ops.", diff --git a/backends/cortex_m/test/models/test_nn_modules.py b/backends/cortex_m/test/models/test_nn_modules.py index 4a92fd578ff..303b481d4bc 100644 --- a/backends/cortex_m/test/models/test_nn_modules.py +++ b/backends/cortex_m/test/models/test_nn_modules.py @@ -1,6 +1,6 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/cortex_m/test/ops/test_activation.py b/backends/cortex_m/test/ops/test_activation.py index 8886a05a84b..0934386d67c 100644 --- a/backends/cortex_m/test/ops/test_activation.py +++ b/backends/cortex_m/test/ops/test_activation.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -398,6 +398,154 @@ def forward(self, x): return torch.clamp(self.linear(x), min=None, max=6.0) +class CortexMStandaloneReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.relu(x) + + +class CortexMStandaloneHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.nn.functional.hardtanh(x, -1.0, 1.0) + + +class CortexMStandaloneClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.clamp(x, -1.0, 1.0) + + +class CortexMStandaloneClampTensor(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_full_like_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 4, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 5, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + def forward(self, x): + return torch.ops.aten.clamp.Tensor( + x, torch.full_like(x, -1.0), torch.full_like(x, 1.0) + ) + + +class CortexMMaxPool2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.pool(x)) + + +class CortexMMaxPool2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self, min_val=-0.5, max_val=0.5): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.nn.functional.hardtanh(self.pool(x), self.min_val, self.max_val) + + +class CortexMMaxPool2DClamp(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1, + "executorch_exir_dialects_edge__ops_aten_clamp_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_max_pool2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_clamp_default"] + + def __init__(self, min_val=-0.25, max_val=0.75): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.min_val = min_val + self.max_val = max_val + + def forward(self, x): + return torch.clamp(self.pool(x), self.min_val, self.max_val) + + test_cases = { # Linear + activation tests with various data ranges "linear_relu_small_range": McuTestCase( @@ -509,6 +657,40 @@ def forward(self, x): model=CortexMLinearClamp(in_features=4, out_features=3), example_inputs=(ramp_tensor(-10, 10, (1, 4)),), ), + "standalone_relu": McuTestCase( + model=CortexMStandaloneReLU(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_hardtanh": McuTestCase( + model=CortexMStandaloneHardtanh(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp": McuTestCase( + model=CortexMStandaloneClamp(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "standalone_clamp_tensor": McuTestCase( + model=CortexMStandaloneClampTensor(), + example_inputs=(ramp_tensor(-5, 5, (2, 3, 4, 5)),), + ), + "maxpool_relu": McuTestCase( + model=CortexMMaxPool2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_hardtanh": McuTestCase( + model=CortexMMaxPool2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), + "maxpool_clamp": McuTestCase( + model=CortexMMaxPool2DClamp(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 8, 8)).to(memory_format=torch.channels_last), + ), + ), } @@ -520,6 +702,8 @@ def test_dialect_activation(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) @parametrize("test_case", test_cases) diff --git a/backends/cortex_m/test/ops/test_conv_transpose.py b/backends/cortex_m/test/ops/test_conv_transpose.py index 7a91c5e1b6b..8202e3dc999 100644 --- a/backends/cortex_m/test/ops/test_conv_transpose.py +++ b/backends/cortex_m/test/ops/test_conv_transpose.py @@ -60,6 +60,61 @@ def forward(self, x): return self.conv_transpose(x) +class CortexMConvTranspose2DReLU(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_relu_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_relu_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.conv_transpose(x)) + + +class CortexMConvTranspose2DHardtanh(torch.nn.Module): + ops_before_transforms = { + "executorch_exir_dialects_edge__ops_aten_convolution_default": 1, + "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_channel_default": 2, + } + + ops_after_transforms = { + "executorch_exir_dialects_edge__ops_cortex_m_quantized_transpose_conv2d_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1, + } + + ops_after_absent = ["executorch_exir_dialects_edge__ops_aten_hardtanh_default"] + + def __init__(self): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + 4, 2, kernel_size=3, stride=2, padding=1, bias=True + ) + + def forward(self, x): + return torch.nn.functional.hardtanh(self.conv_transpose(x), -0.5, 0.5) + + # Test cases covering various configurations test_cases = { # Basic test case @@ -123,6 +178,18 @@ def forward(self, x): ramp_tensor(0, 50, (1, 5, 4, 4)).to(memory_format=torch.channels_last), ), ), + "conv_transpose2d_relu": McuTestCase( + model=CortexMConvTranspose2DReLU(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), + "conv_transpose2d_hardtanh": McuTestCase( + model=CortexMConvTranspose2DHardtanh(), + example_inputs=( + ramp_tensor(-10, 10, (1, 4, 4, 4)).to(memory_format=torch.channels_last), + ), + ), # Dilation variation "conv_transpose2d_dilation_2": McuTestCase( model=CortexMConvTranspose2D(2, 4, kernel_size=3, dilation=2), @@ -244,12 +311,14 @@ def test_dialect_conv_transpose2d(test_case): test_case.model.ops_after_transforms, qtol=1, ) + if hasattr(test_case.model, "ops_after_absent"): + tester.check_not(test_case.model.ops_after_absent) -# Implementation xfails: empty because unsupported configurations are now -# rejected at AOT time by the quantizer filter, so they fall back to portable -# ops and work correctly. Only xfails_dialect needs to track these. -xfails_implementation: dict[str, xfail_type] = {} +xfails_implementation: dict[str, xfail_type] = { + "conv_transpose2d_relu": "Fused transpose-conv + relu lowers correctly but current implementation is numerically incorrect.", + "conv_transpose2d_hardtanh": "Fused transpose-conv + hardtanh lowers correctly but current implementation is numerically incorrect.", +} @parametrize("test_case", test_cases, xfails=xfails_implementation)