127 changes: 93 additions & 34 deletions backends/cortex_m/quantizer/quantizer.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import logging
from collections import defaultdict
from typing import Any, Callable, cast, Iterator, List, Optional

@@ -40,6 +41,16 @@
)
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

logger = logging.getLogger(__name__)


def has_float_output(node: Node) -> bool:
    meta_val = node.meta.get("val", None)
    if isinstance(meta_val, torch.Tensor):
        return meta_val.dtype.is_floating_point

    return False


def mark_node_as_annotated(
    node: Node,
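For context on `has_float_output` above: `node.meta["val"]` holds the (fake) tensor recorded during export, so the check reduces to a dtype test on that value. A minimal standalone sketch of the same check (the `Toy` module and printout are illustrative assumptions, not code from this PR):

```python
import torch
from torch.export import export


class Toy(torch.nn.Module):
    def forward(self, x):
        # argmax yields an int64 tensor, so its node has no float output
        return torch.argmax(x - x, dim=-1)


ep = export(Toy(), (torch.randn(4, 4),))
for node in ep.graph.nodes:
    meta_val = node.meta.get("val", None)
    # Same dtype test as has_float_output
    is_float = isinstance(meta_val, torch.Tensor) and meta_val.dtype.is_floating_point
    print(f"{node.name}: float output = {is_float}")
```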
@@ -356,6 +367,9 @@ def annotate_match(
), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}."

for input_node in node.all_input_nodes:
# Observers only work on floating point tensors, so make sure to skip other dtypes
if not has_float_output(input_node):
continue
if self.is_weight(input_node, params, model):
input_qspec_map[input_node] = config.weight if config else None
elif self.is_bias(input_node, params, model):
@@ -487,56 +501,101 @@ def __init__(self, targets: Optional[List[OpOverload]] = None) -> None:
    def _is_annotated(self, node: Node) -> bool:
        return Q_ANNOTATION_KEY in node.meta

    def _annotate_shared_cluster(self, root_node: Node) -> None:
    def _get_input_nodes_with_float_output(self, node: Node) -> List[Node]:
        # Observers only work on floating point tensors, so make sure to skip other dtypes
        return [n for n in node.all_input_nodes if has_float_output(n)]

    def _get_user_nodes_with_float_input(self, node: Node) -> List[Node]:
        # Observers only work on floating point tensors, so make sure to skip other dtypes
        return [n for n in node.users.keys() if has_float_output(node)]

    def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], set[QuantizationSpec]]:
        """
        Finds a cluster of unannotated nodes starting in root_node and annotates them with a common
        SharedQuantizationSpec.
        Finds a cluster of nodes with targets in self.targets, starting in root_node.
        """

        shared_nodes = set()
        leaf_nodes = set()
        bfs_queue = [root_node]
        adjacent_qspecs = set()

        while bfs_queue:
            node = bfs_queue.pop(0)
            shared_nodes.add(node)

            if self._is_annotated(node):
                leaf_nodes.add(node)
                continue
            if node.op == "get_attr":
                continue
            # Neighbours may either be other shared nodes, annotated nodes, or non-annotated (float) nodes.
            for input_node in self._get_input_nodes_with_float_output(node):
                if input_node.target in self.targets and input_node not in shared_nodes:
                    if not self._is_annotated(input_node):
                        bfs_queue.append(input_node)
                    if self._is_annotated(input_node):
                        output_qspec = input_node.meta.get(
                            Q_ANNOTATION_KEY, None
                        ).output_qspec
                        adjacent_qspecs.add(output_qspec)

            for output_node in self._get_user_nodes_with_float_input(node):
                if (
                    output_node.target in self.targets
                    and output_node not in shared_nodes
                ):
                    if not self._is_annotated(output_node):
                        bfs_queue.append(output_node)
                    if self._is_annotated(output_node):
                        input_qspec = output_node.meta.get(
                            Q_ANNOTATION_KEY, None
                        ).input_qspec_map[node]
                        adjacent_qspecs.add(input_qspec)

        return shared_nodes, adjacent_qspecs

            if node.target not in self.targets:
                raise NotImplementedError(
                    (
                        f"{SharedQspecQuantizer.__name__} found unannotated node '{node.name}' in neighbour_nodes "
                        "which is not in the supported target list. This might be the case either because:\n"
                        "1) The op should have shared qspec but is not in the target list. "
                        "In this case, try modifying the list using the targets field in the initializer.\n"
                        "2) The op should not be quantized, which is not currently supported by the SharedQspecQuantizer."
                    )
                )
    def _annotate_shared_cluster(self, root_node: Node) -> None:
        """
        Finds a cluster of unannotated nodes starting in root_node and annotates them with a common
        SharedQuantizationSpec.
        """

            shared_nodes.add(node)
            neighbour_nodes = list(node.all_input_nodes) + list(node.users)
            for n in neighbour_nodes:
                if n not in shared_nodes:
                    bfs_queue.append(n)
        shared_nodes, adjacent_qspecs = self._get_shared_clique(root_node)

        # The selection of root node for the shared_qspec is important for
        # torchao.quantization.pt2e.prepare._create_obs_or_fq_from_qspec:
        # 1. For regular QuantizationSpecs, it creates a new observer
        # 2. For SharedQuantizationSpecs, it returns the observer created for its root node
        # 3. It handles nodes in the order they appear in graph.nodes
        # This means that the root node of the shared group needs to be the first annotated node that appears in graph.nodes.
        shared_root_node = next(n for n in root_node.graph.nodes if n in leaf_nodes)
        shared_qspec = SharedQuantizationSpec(shared_root_node)

        for node in shared_nodes:
            input_qspec_map: dict[Node, Optional[QuantizationSpec]] = {
                n: shared_qspec for n in node.all_input_nodes
            }
            mark_node_as_annotated(node, input_qspec_map, shared_qspec)
        # This means that we need to make sure that the root node of the shared_qspec
        # has an input node with a quantization spec, so that an observer is created.

        if len(adjacent_qspecs) == 1:
            root_node_first_input = self._get_input_nodes_with_float_output(root_node)[
                0
            ]

            # Make all nodes share qspec with the root node's first input
            shared_qspec = SharedQuantizationSpec((root_node_first_input, root_node))
            for node in shared_nodes:
                input_qspec_map: dict[Node, Optional[QuantizationSpec]] = {
                    n: shared_qspec
                    for n in self._get_input_nodes_with_float_output(node)
                }
                if len(self._get_user_nodes_with_float_input(node)) == 0:
                    output_qspec = None
                else:
                    output_qspec = shared_qspec
                mark_node_as_annotated(node, input_qspec_map, output_qspec)

            # Force the root qspec to be the adjacent spec
            root_node.meta[Q_ANNOTATION_KEY].input_qspec_map[
                root_node_first_input
            ] = adjacent_qspecs.pop()

        elif len(adjacent_qspecs) == 0:
            logger.warning(
                "SharedQspecQuantizer found a cluster of supported ops surrounded by no quantized ops - leaving nodes unquantized."
            )
            return
        else:
            logger.warning(
                "SharedQspecQuantizer found a cluster of supported ops surrounded by multiple different qspecs - leaving nodes unquantized."
            )
            return

    def annotate(self, model: GraphModule) -> None:
        """
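A note on the edge form used in `_annotate_shared_cluster`: `SharedQuantizationSpec` accepts either a node or an `(input_node, user_node)` edge, and the prepare step reuses whatever observer was created for that reference. Forcing the root edge's entry in `input_qspec_map` to the adjacent qspec therefore guarantees the shared observer actually exists. A condensed sketch of the pairing, assuming the torchao import path this file already uses (the helper name is illustrative):

```python
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import SharedQuantizationSpec


def cluster_spec(root_input: Node, root_node: Node) -> SharedQuantizationSpec:
    # Edge form: every node annotated with this spec reuses the observer
    # created for the root_input -> root_node edge, so the whole cluster
    # collapses to a single set of qparams.
    return SharedQuantizationSpec((root_input, root_node))
```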
55 changes: 49 additions & 6 deletions backends/cortex_m/test/misc/test_quantization.py
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -277,6 +277,27 @@ def forward(self, x, y):
        return torch.clone(x - y)


class SharedQspecCompetingQspecs(torch.nn.Module):
    ops_before_transforms = {}
    ops_after_transforms = {}

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)

    def forward(self, x):
        return torch.cat([self.conv(x), x], dim=1)


class SharedQspecNoQspecs(torch.nn.Module):
    ops_before_transforms = {}
    ops_after_transforms = {}

    def forward(self, x):
        z = torch.clone(x - x)
        return z - z


test_cases = {
    "multiple_clusters": McuTestCase(
        SharedQspecMulipleClusters(),
@@ -326,15 +347,10 @@ def forward(self, x, y):
        SharedQspecManyForks(),
        (ramp_tensor(-20, 2, (4, 4)),),
    ),
    "non-quantized_op": McuTestCase(
        SharedQspecSub(),
        (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))),
    ),
}

xfails = {
    "surrounded_quantized_op_constant": "Numerical error since the add is forced to have non-correct qparams.",
    "non-quantized_op": "Non-quantized ops are not currently supported in SharedQspecQuantizer.",
}


@@ -357,3 +373,30 @@ def test_shared_qspec_quantizer(test_case):
            continue

        assert get_first_fake_tensor(node).dtype == torch.int8, f"{node.name}"


float_test_cases = {
    "non-quantized_op": McuTestCase(
        SharedQspecSub(),
        (ramp_tensor(0, 10, (5, 5)), ramp_tensor(0, 1, (5, 5))),
    ),
    "competing_qspecs": McuTestCase(
        SharedQspecCompetingQspecs(),
        (ramp_tensor(0, 10, (1, 3, 5, 5)).to(memory_format=torch.channels_last),),
    ),
    "no_qspecs": McuTestCase(
        SharedQspecNoQspecs(),
        (ramp_tensor(0, 10, (1, 3, 5, 5)),),
    ),
}


@parametrize("test_case", float_test_cases)
def test_shared_qspec_quantizer_no_qspecs(test_case):
    """
    Test that ops which do not change dynamic range are able to use int8 portable kernels.
    """
    tester = CortexMTester(test_case.model, test_case.example_inputs)
    tester.test_dialect(
        test_case.model.ops_before_transforms, test_case.model.ops_after_transforms
    )
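The float cases above exercise the warning branches of `_annotate_shared_cluster`: "no_qspecs" and "non-quantized_op" end up with no adjacent qspecs, while "competing_qspecs" sees two different ones. Condensed for reference (a sketch mirroring the diff's branching, not backend code):

```python
from typing import Optional


def cluster_outcome(num_adjacent_qspecs: int) -> Optional[str]:
    # Mirrors the branching in _annotate_shared_cluster
    if num_adjacent_qspecs == 1:
        return "share the single adjacent qspec across the cluster"
    if num_adjacent_qspecs == 0:
        return None  # nothing to share: cluster stays float
    return None  # competing qspecs: ambiguous, cluster stays float
```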