diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 628110411e2..7fa2ac6f224 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -55,6 +55,7 @@ def trace( inputs: tuple[object, ...], dump_graphs: bool = False, ops_to_keep: Optional[list[torch._ops.OpOverload]] = None, + is_qat: bool = False, ) -> ExportedProgram: """ Trace the model with export and return an ExportedProgram. @@ -62,7 +63,7 @@ def trace( if ops_to_keep is None: ops_to_keep = [] program = trace_fn( - model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep + model, inputs, is_qat=is_qat, strict=True, ops_to_keep=ops_to_keep ) if dump_graphs: @@ -77,6 +78,7 @@ def prepare_pt2( inputs: tuple[object, ...], quantizer: CadenceQuantizer, dump_graphs: bool = False, + is_qat: bool = False, ) -> torch.fx.GraphModule: """ Trace and Prepare a model using the given quantizer. @@ -89,10 +91,10 @@ def prepare_pt2( ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition() traced_program = trace( - model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep + model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat ) prepared_program = prepare_traced_pt2( - traced_program, quantizer, dump_graphs=dump_graphs + traced_program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat ) return prepared_program @@ -102,6 +104,7 @@ def prepare_traced_pt2( program: ExportedProgram, quantizer: CadenceQuantizer, dump_graphs: bool = False, + is_qat: bool = False, ) -> torch.fx.GraphModule: """ Prepare a model using the given quantizer. @@ -112,7 +115,7 @@ def prepare_traced_pt2( Returns a GraphModule with the prepared model. """ - prepared_model = prepare_fn(program, quantizer, is_qat=False) + prepared_model = prepare_fn(program, quantizer, is_qat=is_qat) if dump_graphs: logging.info("Graph after preparation:") diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 9399efe632a..071e8f91b13 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -41,7 +41,12 @@ no_outside_users, ) from torch import fx -from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e import ( + FakeQuantize, + FusedMovingAvgObsFakeQuantize, + HistogramObserver, + MinMaxObserver, +) from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, DerivedQuantizationSpec, @@ -106,6 +111,47 @@ None, ) +act_qat_qspec_asym8s = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize, +) + +wgt_qat_qspec_asym8s = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver), +) + +wgt_qat_qspec_sym8s = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver), +) + +qconfig_A8W8_qat = QuantizationConfig( + act_qat_qspec_asym8s, + act_qat_qspec_asym8s, + wgt_qat_qspec_asym8s, + None, +) + +qconfig_A8W8sym_qat = QuantizationConfig( + act_qat_qspec_asym8s, + act_qat_qspec_asym8s, + wgt_qat_qspec_sym8s, + None, +) + qconfig_A16 = QuantizationConfig( act_qspec_asym16s, act_qspec_asym16s, @@ -221,18 +267,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]: return [] -def get_cadence_default_quantizers() -> List[Quantizer]: +def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]: + a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8 + a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym return [ - CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8), - CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8), - CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym), - CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym), - CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8), - CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8), - CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8), - CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8), - CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8), - CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8), + CadenceAtenQuantizer(AddmmPattern(), a8w8), + CadenceAtenQuantizer(BmmPattern(), a8w8), + CadenceAtenQuantizer(Conv1dPattern(), a8w8sym), + CadenceAtenQuantizer(Conv2dPattern(), a8w8sym), + CadenceAtenQuantizer(LinearPattern(), a8w8), + CadenceAtenQuantizer(MatmulPattern(), a8w8), + CadenceAtenQuantizer(MaxPool2dPattern(), a8w8), + CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), a8w8), + CadenceAtenQuantizer(ReluPattern0(), a8w8), + CadenceAtenQuantizer(ReluPattern1(), a8w8), ] @@ -270,9 +318,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer): Default quantizer for Cadence backend. """ - def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + def __init__( + self, + quantizers: Optional[list[Quantizer]] = None, + is_qat: bool = False, + ) -> None: if quantizers is None: - quantizers = get_cadence_default_quantizers() + quantizers = get_cadence_default_quantizers(is_qat=is_qat) super().__init__(quantizers) @@ -314,11 +366,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer): Quantizer for WakeWord, including add and cat """ - def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + def __init__( + self, + quantizers: Optional[list[Quantizer]] = None, + is_qat: bool = False, + ) -> None: if quantizers is None: - quantizers = get_cadence_default_quantizers() - quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) - quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) + quantizers = get_cadence_default_quantizers(is_qat=is_qat) + a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8 + quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8)) + quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8)) super().__init__(quantizers) @@ -327,17 +384,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer): Quantizer using fused conv+relu patterns, and including add and cat """ - def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None: + def __init__( + self, + quantizers: Optional[list[Quantizer]] = None, + is_qat: bool = False, + ) -> None: if quantizers is None: quantizers = [] + a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8 + a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym # Order matters here, perform the "fused" patterns first - quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym)) - quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym)) - quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym)) - quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym)) - quantizers = quantizers + get_cadence_default_quantizers() - quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8)) - quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8)) + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym)) + quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym)) + quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat) + quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8)) + quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8)) super().__init__(quantizers)