diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py
index c9b42bd15e5..515803f0930 100644
--- a/backends/cortex_m/quantizer/quantizer.py
+++ b/backends/cortex_m/quantizer/quantizer.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import logging
 from collections import defaultdict
 from typing import Any, Callable, cast, Iterator, List, Optional
 
@@ -40,6 +41,16 @@
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 
+logger = logging.getLogger(__name__)
+
+
+def has_float_output(node: Node) -> bool:
+    meta_val = node.meta.get("val", None)
+    if isinstance(meta_val, torch.Tensor):
+        return meta_val.dtype.is_floating_point
+
+    return False
+
 
 def mark_node_as_annotated(
     node: Node,
@@ -356,6 +367,9 @@ def annotate_match(
         ), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}."
 
         for input_node in node.all_input_nodes:
+            # Observers only work on floating point tensors, so make sure to skip other dtypes
+            if not has_float_output(input_node):
+                continue
             if self.is_weight(input_node, params, model):
                 input_qspec_map[input_node] = config.weight if config else None
             elif self.is_bias(input_node, params, model):
@@ -487,56 +501,107 @@ def __init__(self, targets: Optional[List[OpOverload]] = None) -> None:
 
     def _is_annotated(self, node: Node) -> bool:
         return Q_ANNOTATION_KEY in node.meta
 
-    def _annotate_shared_cluster(self, root_node: Node) -> None:
+    def _get_input_nodes_with_float_output(self, node: Node) -> List[Node]:
+        # Observers only work on floating point tensors, so make sure to skip other dtypes
+        return [n for n in node.all_input_nodes if has_float_output(n)]
+
+    def _get_user_nodes_with_float_input(self, node: Node) -> List[Node]:
+        # Observers only work on floating point tensors, so make sure to skip other dtypes
+        return [n for n in node.users.keys() if has_float_output(node)]
+
+    def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], set]:
         """
-        Finds a cluster of unannotated nodes starting in root_node and annotates them with a common
-        SharedQuantizationSpec.
+        Finds a cluster of nodes with targets in self.targets, starting in root_node.
+        Returns the cluster together with the set of qspecs adjacent to it.
         """
         shared_nodes = set()
-        leaf_nodes = set()
         bfs_queue = [root_node]
+        adjacent_qspecs = set()
 
         while bfs_queue:
             node = bfs_queue.pop(0)
+            shared_nodes.add(node)
 
-            if self._is_annotated(node):
-                leaf_nodes.add(node)
-                continue
-            if node.op == "get_attr":
-                continue
+            # Neighbours may either be other shared nodes, annotated nodes, or non-annotated (float) nodes.
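+            # Unannotated target ops join the clique and are traversed further, while
+            # annotated neighbours contribute their qspec on the shared edge to adjacent_qspecs.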
+            for input_node in self._get_input_nodes_with_float_output(node):
+                if input_node.target in self.targets and input_node not in shared_nodes:
+                    if not self._is_annotated(input_node):
+                        bfs_queue.append(input_node)
+                if self._is_annotated(input_node):
+                    output_qspec = input_node.meta.get(
+                        Q_ANNOTATION_KEY, None
+                    ).output_qspec
+                    adjacent_qspecs.add(output_qspec)
+
+            for output_node in self._get_user_nodes_with_float_input(node):
+                if (
+                    output_node.target in self.targets
+                    and output_node not in shared_nodes
+                ):
+                    if not self._is_annotated(output_node):
+                        bfs_queue.append(output_node)
+                if self._is_annotated(output_node):
+                    input_qspec = output_node.meta.get(
+                        Q_ANNOTATION_KEY, None
+                    ).input_qspec_map[node]
+                    adjacent_qspecs.add(input_qspec)
+
+        return shared_nodes, adjacent_qspecs
 
-            if node.target not in self.targets:
-                raise NotImplementedError(
-                    (
-                        f"{SharedQspecQuantizer.__name__} found unannoted node '{node.name}' in neighbour_nodes "
-                        "which is not in the supported target list. This might be the case either because:\n"
-                        "1) The op should have shared qspec but is not in the target list. "
-                        "In this case, try modifying the list using the targets field in the initializer.\n"
-                        "2) The op should not be quantized, which is not currently supported by the SharedQspecQuantizer."
-                    )
-                )
+    def _annotate_shared_cluster(self, root_node: Node) -> None:
+        """
+        Finds a cluster of unannotated nodes starting in root_node and annotates them with a common
+        SharedQuantizationSpec.
+        """
 
-            shared_nodes.add(node)
-            neighbour_nodes = list(node.all_input_nodes) + list(node.users)
-            for n in neighbour_nodes:
-                if n not in shared_nodes:
-                    bfs_queue.append(n)
+        shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node)
 
         # The selection of root node for the shared_qspec is important for
         # torchao.quantization.pt2e.prepare._create_obs_or_fq_from_qspec:
         # 1. For regular QuantizationSpecs, it creates a new observer
         # 2. For SharedQuantizationSpecs, it returns the observer created for it's root node
         # 3. It handles nodes in the order they appear in graph.nodes
-        # This means that the root node of the shared group needs to be the first annotated node that appears in graph.nodes.
-        shared_root_node = next(n for n in root_node.graph.nodes if n in leaf_nodes)
-        shared_qspec = SharedQuantizationSpec(shared_root_node)
-
-        for node in shared_nodes:
-            input_qspec_map: dict[Node, Optional[QuantizationSpec]] = {
-                n: shared_qspec for n in node.all_input_nodes
-            }
-            mark_node_as_annotated(node, input_qspec_map, shared_qspec)
+        # This means that we need to make sure that the root node of the shared_qspec
+        # has an input node with a quantization spec, so that an observer is created.
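+        # Anchoring the shared spec on the edge (first_input, root_node) and then forcing
+        # that edge's qspec to the adjacent one (below) guarantees such an observer.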
+
+        if len(adjacent_qspecs) == 1:
+            root_node_first_input = self._get_input_nodes_with_float_output(root_node)[
+                0
+            ]
+
+            # Make all nodes share qspec with the root node's first input
+            shared_qspec = SharedQuantizationSpec((root_node_first_input, root_node))
+            for node in shared_nodes:
+                input_qspec_map: dict[Node, Optional[QuantizationSpec]] = {
+                    n: shared_qspec
+                    for n in self._get_input_nodes_with_float_output(node)
+                }
+                if len(self._get_user_nodes_with_float_input(node)) == 0:
+                    output_qspec = None
+                else:
+                    output_qspec = shared_qspec
+                mark_node_as_annotated(node, input_qspec_map, output_qspec)
+
+            # Force the root qspec to be the adjacent spec
+            root_node.meta[Q_ANNOTATION_KEY].input_qspec_map[
+                root_node_first_input
+            ] = adjacent_qspecs.pop()
+
+        elif len(adjacent_qspecs) == 0:
+            logger.warning(
+                "SharedQspecQuantizer found a cluster of supported ops surrounded by no quantized ops - leaving nodes unquantized."
+            )
+            return
+        else:
+            logger.warning(
+                "SharedQspecQuantizer found a cluster of supported ops surrounded by multiple different qspecs - leaving nodes unquantized."
+            )
+            return
 
     def annotate(self, model: GraphModule) -> None:
         """
diff --git a/backends/cortex_m/test/misc/test_quantization.py b/backends/cortex_m/test/misc/test_quantization.py
index d4f84e4f075..6532449fb55 100644
--- a/backends/cortex_m/test/misc/test_quantization.py
+++ b/backends/cortex_m/test/misc/test_quantization.py
@@ -1,4 +1,4 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -277,6 +277,27 @@ def forward(self, x, y):
         return torch.clone(x - y)
 
 
+class SharedQspecCompetingQspecs(torch.nn.Module):
+    ops_before_transforms = {}
+    ops_after_transforms = {}
+
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(3, 3, 1)
+
+    def forward(self, x):
+        return torch.cat([self.conv(x), x], dim=1)
+
+
+class SharedQspecNoQspecs(torch.nn.Module):
+    ops_before_transforms = {}
+    ops_after_transforms = {}
+
+    def forward(self, x):
+        z = torch.clone(x - x)
+        return z - z
+
+
 test_cases = {
     "multiple_clusters": McuTestCase(
         SharedQspecMulipleClusters(),
@@ -326,15 +347,10 @@ def forward(self, x, y):
         SharedQspecManyForks(),
         (ramp_tensor(-20, 2, (4, 4)),),
     ),
-    "non-quantized_op": McuTestCase(
-        SharedQspecSub(),
-        (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))),
-    ),
 }
 
 xfails = {
     "surrounded_quantized_op_constant": "Numerical error since the add is forced to have non-correct qparams.",
-    "non-quantized_op": "Non-quantized ops are not currently supported in SharedQspecQuantizer.",
 }
 
 
@@ -357,3 +373,32 @@ def test_shared_qspec_quantizer(test_case):
             continue
 
         assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}"
+
+
+float_test_cases = {
+    "non-quantized_op": McuTestCase(
+        SharedQspecSub(),
+        (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))),
+    ),
+    "competing_qspecs": McuTestCase(
+        SharedQspecCompetingQspecs(),
+        (ramp_tensor(0, 10, (1, 3, 5, 5)).to(memory_format=torch.channels_last),),
+    ),
+    "no_qspecs": McuTestCase(
+        SharedQspecNoQspecs(),
+        (ramp_tensor(0, 10, (1, 3, 5, 5)),),
+    ),
+}
+
+
+@parametrize("test_case", float_test_cases)
+def test_shared_qspec_quantizer_no_qspecs(test_case):
+    """
+    Test that ops which do not change dynamic range are able to use int8 portable kernels.
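+    Covers clusters with no adjacent qspecs, with competing adjacent qspecs, and with
+    non-quantized neighbour ops.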
+ """ + tester = CortexMTester(test_case.model, test_case.example_inputs) + tester.test_dialect( + test_case.model.ops_before_transforms, test_case.model.ops_after_transforms + )