diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index 29d77db2094..bf7f79127a0 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -8,7 +8,8 @@ import logging import operator -from typing import Any, Optional, Union +from collections.abc import Mapping, Sequence +from typing import Any, cast, Optional, Union import torch from torch._inductor.decomposition import remove_decompositions @@ -301,23 +302,27 @@ def __init__( "Warning: Using pre-quantized inputs. This should only be done when calibration has been confirmed." "Incorrect quantization parameters can lead to significant accuracy degradation." ) - if isinstance(input_args, list): - self.quant_args = extract_input_quant_params_from_graph(module, input_args) - elif isinstance(input_args, dict): + if isinstance(input_args, Sequence) and not isinstance( + input_args, (str, bytes) + ): + self.quant_args = extract_input_quant_params_from_graph( + module, list(input_args) + ) + elif isinstance(input_args, Mapping): # dict[int, QuantArgs] — use directly # dict[int, list[str]] — extract quant params from graph, keyed by input index first_value = next(iter(input_args.values()), None) if ( - isinstance(first_value, (list, tuple)) + isinstance(first_value, (list, tuple, Sequence)) + and not isinstance(first_value, (str, bytes)) and first_value and isinstance(first_value[0], str) ): # Values are lists of node names: extract quant params and map # to the caller-specified input indices. for input_idx, node_names in input_args.items(): - assert isinstance(node_names, list) extracted = extract_input_quant_params_from_graph( - module, node_names + module, list(cast(Sequence[str], node_names)) ) # Use the first extracted quant params for this input index. if extracted: @@ -430,6 +435,7 @@ def _get_transparent_ops() -> set[Any]: torch.ops.aten.view.default, torch.ops.aten.reshape.default, torch.ops.aten.split.Tensor, + torch.ops.aten.chunk.default, torch.ops.aten.slice_copy.Tensor, torch.ops.aten.permute_copy.default, torch.ops.aten.permute.default,