Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/cortex_m/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .convert_to_cortex_m_pass import ConvertToCortexMPass # noqa
from .decompose_hardswish_pass import DecomposeHardswishPass # noqa
from .decompose_mean_pass import DecomposeMeanPass # noqa
from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa
from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa
from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa
from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip
70 changes: 44 additions & 26 deletions backends/cortex_m/passes/activation_fusion_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -8,7 +8,10 @@

import executorch.backends.cortex_m.ops.operators # noqa: F401
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.cortex_m.passes.passes_utils import quantize_val
from executorch.backends.cortex_m.passes.passes_utils import (
get_activation_bounds,
quantize_val,
)

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -23,7 +26,7 @@ class ActivationFusionPass(ExportPass):
"""Fuse activations into preceding Cortex-M quantized operators.

Supported activation patterns:
q-> [conv2d, linear] -> [relu, hardtanh, hardsigmoid] -> dq
q-> [conv2d, linear, max_pool2d] -> [relu, hardtanh, hardsigmoid, clamp] -> dq

Fusing works by clamping the quantized output range (and zero-point when
required) of the preceding Cortex-M operator, then removing the activation
Expand All @@ -37,10 +40,17 @@ class ActivationFusionPass(ExportPass):
exir_ops.edge.aten.clamp.default,
}

MAX_POOL_OPS = {
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
}

FUSE_OPS = {
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.max_pool2d_with_indices.default,
}

def _get_validated_qparams(self, node, input_node):
Expand All @@ -63,30 +73,38 @@ def _get_validated_qparams(self, node, input_node):
)
return None

match node.target:
case exir_ops.edge.aten.relu.default:
quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
quantized_max_val = qmax
case exir_ops.edge.aten.hardtanh.default:
quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax)
quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax)
case exir_ops.edge.aten.hardsigmoid.default:
quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
quantized_max_val = quantize_val(1, scale, zp, qmin, qmax)
case exir_ops.edge.aten.clamp.default:
quantized_min_val = (
quantize_val(node.args[1], scale, zp, qmin, qmax)
if node.args[1] is not None
else qmin
)
# Last arg is removed if none, so check length of args here
quantized_max_val = (
quantize_val(node.args[2], scale, zp, qmin, qmax)
if len(node.args) == 3
else qmax
bounds = get_activation_bounds(node)
if bounds is None:
logger.warning(
"Cannot fuse activation %s because bounds are not compile-time scalars.",
node.name,
)
return None
min_val, max_val = bounds

quantized_min_val = (
quantize_val(min_val, scale, zp, qmin, qmax)
if min_val is not None
else qmin
)
quantized_max_val = (
quantize_val(max_val, scale, zp, qmin, qmax)
if max_val is not None
else qmax
)

if input_node.target in self.MAX_POOL_OPS:
if node.target == exir_ops.edge.aten.hardsigmoid.default:
logger.warning(
"Cannot fuse hardsigmoid %s after max_pool2d because max_pool2d requires matching input/output qparams.",
node.name,
)
case _:
raise RuntimeError(f"Unexpected target {node.target}.")
return None
# Max-pool keeps scale and zero-point unchanged and lowers fused
# activation bounds separately, so only qmin/qmax need updating here.
qparams_dict["qmin"] = int(quantized_min_val)
qparams_dict["qmax"] = int(quantized_max_val)
return qparams_dict

# If the minimal quantized value is larger than the qmin, it means that the quantized range contains
# invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters.
Expand Down
2 changes: 2 additions & 0 deletions backends/cortex_m/passes/cortex_m_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .convert_to_cortex_m_pass import ConvertToCortexMPass
from .decompose_hardswish_pass import DecomposeHardswishPass
from .decompose_mean_pass import DecomposeMeanPass
from .quantized_clamp_activation_pass import QuantizedClampActivationPass
from .quantized_op_fusion_pass import QuantizedOpFusionPass
from .replace_quant_nodes_pass import ReplaceQuantNodesPass

Expand All @@ -42,6 +43,7 @@ class CortexMPassManager(PassManager):
ReplaceScalarWithTensorArgPass,
ReplaceQuantNodesPass,
ActivationFusionPass,
QuantizedClampActivationPass,
DecomposeHardswishPass,
QuantizedOpFusionPass,
ConvertToCortexMPass,
Expand Down
51 changes: 51 additions & 0 deletions backends/cortex_m/passes/passes_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Any

import torch

Expand All @@ -21,6 +22,56 @@ def quantize_val(val, scale, zp, qmin, qmax):
return float(min(max(torch.round(torch.Tensor([val / scale + zp])), qmin), qmax))


def extract_constant_scalar(arg: Any) -> float | None:
if arg is None:
return None
if isinstance(arg, (int, float)):
return float(arg)
if isinstance(arg, Node):
if arg.op == "call_function" and arg.target in {
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.full.default,
torch.ops.aten.full_like.default,
torch.ops.aten.full.default,
}:
fill_arg = arg.args[1] if len(arg.args) > 1 else None
return extract_constant_scalar(fill_arg)
val = arg.meta.get("val")
if val is None:
return None
return extract_constant_scalar(val)
return None


def get_activation_bounds(node: Node) -> tuple[float | None, float | None] | None:
    """Return the (min, max) clamp bounds implied by an activation node.

    A ``None`` entry inside the tuple means that bound is open. A top-level
    ``None`` return means either the node is not a supported activation, or a
    bound that was explicitly supplied could not be resolved to a
    compile-time scalar.
    """
    target = node.target
    if target in (exir_ops.edge.aten.relu.default, exir_ops.edge.aten.relu_.default):
        low, high = 0.0, None
    elif target == exir_ops.edge.aten.hardsigmoid.default:
        low, high = 0.0, 1.0
    elif target in (
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.hardtanh_.default,
    ):
        low = extract_constant_scalar(node.args[1])
        high = extract_constant_scalar(node.args[2])
    elif target in (
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.clamp.Tensor,
    ):
        # Clamp bounds are optional positional args.
        low = extract_constant_scalar(node.args[1]) if len(node.args) > 1 else None
        high = extract_constant_scalar(node.args[2]) if len(node.args) > 2 else None
    else:
        return None

    # A bound that was explicitly provided but did not resolve to a scalar
    # makes the activation non-fusable.
    if len(node.args) > 1 and low is None and node.args[1] is not None:
        return None
    if len(node.args) > 2 and high is None and node.args[2] is not None:
        return None

    return (low, high)


def dequantize_per_tensor_cmsis(
qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
) -> torch.Tensor:
Expand Down
129 changes: 129 additions & 0 deletions backends/cortex_m/passes/quantized_clamp_activation_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any

import torch
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_output_qparams,
)
from executorch.backends.cortex_m.passes.passes_utils import (
get_activation_bounds,
quantize_val,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_manager import PassResult

logger = logging.getLogger(__name__)


class QuantizedClampActivationPass(ExportPass):
    """Canonicalize remaining clamp-like activations on quantized tensors.

    This pass runs after activation fusion, so any remaining relu/hardtanh/clamp
    still needs to execute in the quantized domain. It rewrites relu and
    hardtanh variants to `aten.clamp.default` and quantizes the clamp bounds so
    the portable kernel consumes and produces int8 tensors.
    """

    # Clamp-like activation targets this pass can rewrite.
    TARGETS = {
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.relu_.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.hardtanh_.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.clamp.Tensor,
    }

    def _get_validated_qparams(self, node: Node) -> dict[str, Any] | None:
        """Return per-tensor output qparams for an int8 activation, else None.

        Combines the int8-input check and the per-tensor qparams check so that
        the qparams lookup is performed exactly once per node (the original
        implementation computed it twice: once for validation, once for use).
        """
        input_node = node.args[0] if len(node.args) > 0 else None
        if not isinstance(input_node, Node):
            return None
        try:
            tensor = get_first_fake_tensor(input_node)
        except Exception:
            # Defensive: a node without fake-tensor metadata cannot be checked.
            return None
        if tensor is None or tensor.dtype != torch.int8:
            return None

        try:
            qparams_dict = get_output_qparams(node)[0]._asdict()
        except (ValueError, KeyError):
            logger.warning(
                "Cannot quantize clamp bounds for %s without output qparams.",
                node.name,
            )
            return None

        scale = qparams_dict["scale"]
        zp = qparams_dict["zp"]
        if not isinstance(scale, float) or not isinstance(zp, int):
            logger.warning(
                "Cannot quantize clamp bounds for %s with non per-tensor qparams.",
                node.name,
            )
            return None

        return qparams_dict

    def _is_quantized_int8_activation(self, node: Node) -> bool:
        """Whether node is a clamp-like op on an int8 tensor with usable qparams.

        Kept as a thin boolean wrapper for backward compatibility; the heavy
        lifting (and the qparams lookup) lives in _get_validated_qparams.
        """
        return self._get_validated_qparams(node) is not None

    def _get_quantized_bounds(
        self, node: Node, qparams_dict: dict[str, Any]
    ) -> tuple[int | None, int | None] | None:
        """Quantize the activation's float bounds into the integer domain.

        Returns (min, max) where a None entry means the bound is open, or
        None when the bounds are not compile-time scalars.
        """
        qmin = qparams_dict["qmin"]
        qmax = qparams_dict["qmax"]
        scale = qparams_dict["scale"]
        zp = qparams_dict["zp"]

        bounds = get_activation_bounds(node)
        if bounds is None:
            logger.warning(
                "Cannot rewrite %s because bounds are not compile-time scalars.",
                node.name,
            )
            return None
        min_val, max_val = bounds

        quantized_min = (
            int(quantize_val(min_val, scale, zp, qmin, qmax))
            if min_val is not None
            else None
        )
        quantized_max = (
            int(quantize_val(max_val, scale, zp, qmin, qmax))
            if max_val is not None
            else None
        )
        return quantized_min, quantized_max

    def call(self, graph_module: GraphModule) -> PassResult:
        """Rewrite clamp-like int8 activations to aten.clamp with int bounds."""
        modified = False

        for node in list(graph_module.graph.nodes):
            if node.op != "call_function" or node.target not in self.TARGETS:
                continue

            qparams_dict = self._get_validated_qparams(node)
            if qparams_dict is None:
                continue

            quantized_bounds = self._get_quantized_bounds(node, qparams_dict)
            if quantized_bounds is None:
                continue

            quantized_min, quantized_max = quantized_bounds
            if quantized_min is None and quantized_max is None:
                # aten.clamp requires at least one of min/max; a bound-less
                # activation is a no-op, so leave the node untouched instead
                # of emitting a clamp that would fail at runtime.
                logger.warning(
                    "Skipping rewrite of %s: no clamp bounds to apply.",
                    node.name,
                )
                continue

            # Retarget in place: the original bound-producing nodes (e.g.
            # full/full_like) become dead and are cleaned up below.
            node.target = exir_ops.edge.aten.clamp.default
            node.args = (node.args[0], quantized_min, quantized_max)
            modified = True

        if modified:
            graph_module = super().call(graph_module).graph_module
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()

        return PassResult(graph_module, modified)
40 changes: 40 additions & 0 deletions backends/cortex_m/quantizer/quantization_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from typing import Any, Callable

import torch
Expand Down Expand Up @@ -90,10 +91,45 @@
torch.ops.aten.max_pool2d_with_indices.default,
}

POOL_FUSED_ACTIVATION_TARGETS = {
torch.ops.aten.relu.default,
torch.ops.aten.relu_.default,
torch.ops.aten.hardtanh.default,
torch.ops.aten.hardtanh_.default,
torch.ops.aten.clamp.default,
torch.ops.aten.clamp_.default,
}


class CortexMQuantizationConfig(QuantizationConfig):
"""Configures quantization, while enforcing cortex-m specific constraints."""

@staticmethod
def _get_shared_pool_input(node: Node | None) -> Node | None:
if node is None or len(node.args) == 0:
return None

input_node = node.args[0]
if not isinstance(input_node, Node):
return None

if input_node.target in POOL_SHARE_OUTPUT_TARGETS:
if len(input_node.args) > 0 and isinstance(input_node.args[0], Node):
return input_node.args[0]
return None

if input_node.target == operator.getitem and len(input_node.args) > 0:
pool_node = input_node.args[0]
if (
isinstance(pool_node, Node)
and pool_node.target in POOL_SHARE_OUTPUT_TARGETS
and len(pool_node.args) > 0
and isinstance(pool_node.args[0], Node)
):
return pool_node.args[0]

return None

def get_input_act_qspec(
self, node: Node | None = None, input_node: Node | None = None
) -> QuantizationSpecBase | None:
Expand Down Expand Up @@ -121,6 +157,10 @@ def get_output_act_qspec(
if isinstance(input_node, Node):
return SharedQuantizationSpec((input_node, node))
return super().get_output_act_qspec()
if node is not None and node.target in POOL_FUSED_ACTIVATION_TARGETS:
shared_pool_input = self._get_shared_pool_input(node)
if shared_pool_input is not None:
return SharedQuantizationSpec(shared_pool_input)
return super().get_output_act_qspec()

def get_weight_qspec(self, node: Node | None = None) -> QuantizationSpecBase | None:
Expand Down
Loading
Loading