11 changes: 7 additions & 4 deletions backends/cadence/aot/compiler.py
@@ -55,14 +55,15 @@ def trace(
inputs: tuple[object, ...],
dump_graphs: bool = False,
ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
+    is_qat: bool = False,
) -> ExportedProgram:
"""
Trace the model with export and return an ExportedProgram.
"""
if ops_to_keep is None:
ops_to_keep = []
program = trace_fn(
-        model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
+        model, inputs, is_qat=is_qat, strict=True, ops_to_keep=ops_to_keep
)

if dump_graphs:
@@ -77,6 +78,7 @@ def prepare_pt2(
inputs: tuple[object, ...],
quantizer: CadenceQuantizer,
dump_graphs: bool = False,
+    is_qat: bool = False,
) -> torch.fx.GraphModule:
"""
Trace and Prepare a model using the given quantizer.
@@ -89,10 +91,10 @@

ops_to_keep = quantizer.get_ops_to_preserve_from_decomposition()
traced_program = trace(
-        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep
+        model, inputs, dump_graphs=dump_graphs, ops_to_keep=ops_to_keep, is_qat=is_qat
)
prepared_program = prepare_traced_pt2(
-        traced_program, quantizer, dump_graphs=dump_graphs
+        traced_program, quantizer, dump_graphs=dump_graphs, is_qat=is_qat
)

return prepared_program
@@ -102,6 +104,7 @@ def prepare_traced_pt2(
program: ExportedProgram,
quantizer: CadenceQuantizer,
dump_graphs: bool = False,
+    is_qat: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare a model using the given quantizer.
@@ -112,7 +115,7 @@
Returns a GraphModule with the prepared model.
"""

-    prepared_model = prepare_fn(program, quantizer, is_qat=False)
+    prepared_model = prepare_fn(program, quantizer, is_qat=is_qat)

if dump_graphs:
logging.info("Graph after preparation:")
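Taken together, these changes thread `is_qat` from the public `trace`/`prepare_pt2` entry points down to `prepare_fn`, so callers can opt into quantization-aware training without touching the PTQ path. A minimal sketch of the flow this enables (the toy model, inputs, and the training step are illustrative assumptions, not part of this diff):

```python
import torch

from executorch.backends.cadence.aot.compiler import prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

# Toy model and example inputs (assumptions for illustration).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
inputs = (torch.randn(1, 16),)

# Build a QAT-flavored quantizer; prepare_pt2 now forwards is_qat so the
# traced graph is instrumented with fake-quant modules instead of observers.
quantizer = CadenceDefaultQuantizer(is_qat=True)
prepared = prepare_pt2(model, inputs, quantizer, is_qat=True)

# ...fine-tune `prepared` for a few steps so the fake-quant ranges and the
# weights adapt to quantization error, then convert as usual...
```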
115 changes: 89 additions & 26 deletions backends/cadence/aot/quantizer/quantizer.py
@@ -41,7 +41,12 @@
no_outside_users,
)
from torch import fx
-from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    HistogramObserver,
+    MinMaxObserver,
+)
from torchao.quantization.pt2e.quantizer import (
ComposableQuantizer,
DerivedQuantizationSpec,
@@ -106,6 +111,47 @@
None,
)

+act_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize,
+)
+
+wgt_qat_qspec_asym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_affine,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+wgt_qat_qspec_sym8s = QuantizationSpec(
+    dtype=torch.int8,
+    quant_min=-128,
+    quant_max=127,
+    qscheme=torch.per_tensor_symmetric,
+    is_dynamic=False,
+    observer_or_fake_quant_ctr=FakeQuantize.with_args(observer=MinMaxObserver),
+)
+
+qconfig_A8W8_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_asym8s,
+    None,
+)
+
+qconfig_A8W8sym_qat = QuantizationConfig(
+    act_qat_qspec_asym8s,
+    act_qat_qspec_asym8s,
+    wgt_qat_qspec_sym8s,
+    None,
+)
+
qconfig_A16 = QuantizationConfig(
act_qspec_asym16s,
act_qspec_asym16s,
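The new `*_qat_*` specs differ from their PTQ counterparts only in `observer_or_fake_quant_ctr`: PTQ observers (`HistogramObserver`, `MinMaxObserver`) just record value ranges, while `FakeQuantize`/`FusedMovingAvgObsFakeQuantize` additionally round-trip tensors through quantization in the forward pass, so training sees the rounding error. A standalone sketch of that round-trip (plain torch, a conceptual stand-in for the torchao modules, not their implementation):

```python
import torch

def fake_quantize(x: torch.Tensor, scale: float, zero_point: int,
                  quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    """Quantize-dequantize round trip, as a fake-quant module applies in forward."""
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale

x = torch.randn(8)
x_fq = fake_quantize(x, scale=0.02, zero_point=0)  # int8-shaped error, float dtype
# During QAT the model trains on x_fq; a straight-through estimator lets
# gradients flow through the non-differentiable round().
```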
@@ -221,18 +267,20 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
return []


-def get_cadence_default_quantizers() -> List[Quantizer]:
+def get_cadence_default_quantizers(is_qat: bool = False) -> List[Quantizer]:
+    a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+    a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
return [
-        CadenceAtenQuantizer(AddmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(BmmPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(Conv1dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(Conv2dPattern(), qconfig_A8W8sym),
-        CadenceAtenQuantizer(LinearPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MatmulPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern0(), qconfig_A8W8),
-        CadenceAtenQuantizer(ReluPattern1(), qconfig_A8W8),
+        CadenceAtenQuantizer(AddmmPattern(), a8w8),
+        CadenceAtenQuantizer(BmmPattern(), a8w8),
+        CadenceAtenQuantizer(Conv1dPattern(), a8w8sym),
+        CadenceAtenQuantizer(Conv2dPattern(), a8w8sym),
+        CadenceAtenQuantizer(LinearPattern(), a8w8),
+        CadenceAtenQuantizer(MatmulPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dPattern(), a8w8),
+        CadenceAtenQuantizer(MaxPool2dWithoutIndicesPattern(), a8w8),
+        CadenceAtenQuantizer(ReluPattern0(), a8w8),
+        CadenceAtenQuantizer(ReluPattern1(), a8w8),
]


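The factory keeps a single pattern list and swaps only the qconfig pair, so PTQ and QAT graphs are annotated identically apart from the observer-vs-fake-quant choice; for example (illustrative):

```python
ptq = get_cadence_default_quantizers()             # observer-based PTQ configs
qat = get_cadence_default_quantizers(is_qat=True)  # fake-quant QAT configs
assert len(ptq) == len(qat)  # same patterns either way, only the qspecs differ
```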
@@ -270,9 +318,13 @@ class CadenceDefaultQuantizer(CadenceQuantizer):
Default quantizer for Cadence backend.
"""

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
super().__init__(quantizers)


@@ -314,11 +366,16 @@ class CadenceWakeWordQuantizer(CadenceQuantizer):
Quantizer for WakeWord, including add and cat
"""

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
if quantizers is None:
-            quantizers = get_cadence_default_quantizers()
-            quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-            quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+            quantizers = get_cadence_default_quantizers(is_qat=is_qat)
+            a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+            quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+            quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
super().__init__(quantizers)


@@ -327,17 +384,23 @@ class CadenceFusedConvReluQuantizer(CadenceQuantizer):
Quantizer using fused conv+relu patterns, and including add and cat
"""

-    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+    def __init__(
+        self,
+        quantizers: Optional[list[Quantizer]] = None,
+        is_qat: bool = False,
+    ) -> None:
if quantizers is None:
quantizers = []
+        a8w8 = qconfig_A8W8_qat if is_qat else qconfig_A8W8
+        a8w8sym = qconfig_A8W8sym_qat if is_qat else qconfig_A8W8sym
# Order matters here, perform the "fused" patterns first
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), qconfig_A8W8sym))
-        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), qconfig_A8W8sym))
-        quantizers = quantizers + get_cadence_default_quantizers()
-        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8W8))
-        quantizers.append(CadenceAtenQuantizer(CatPattern(), qconfig_A8W8))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv1dReluPattern1(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern0(), a8w8sym))
+        quantizers.append(CadenceAtenQuantizer(Conv2dReluPattern1(), a8w8sym))
+        quantizers = quantizers + get_cadence_default_quantizers(is_qat=is_qat)
+        quantizers.append(CadenceAtenQuantizer(AddPattern(), a8w8))
+        quantizers.append(CadenceAtenQuantizer(CatPattern(), a8w8))
super().__init__(quantizers)


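As the in-code comment notes, order matters: the fused `Conv1dReluPattern*`/`Conv2dReluPattern*` quantizers are registered before the standalone conv and relu patterns appended by `get_cadence_default_quantizers`, so a conv followed by relu is annotated as one fused quantization region rather than being claimed piecemeal by the standalone patterns. Usage mirrors the default quantizer (sketch; `model` and `inputs` are assumptions as above):

```python
quantizer = CadenceFusedConvReluQuantizer(is_qat=True)
prepared = prepare_pt2(model, inputs, quantizer, is_qat=True)
# conv->relu chains now carry a single symmetric weight fake-quant annotation,
# ready for QAT fine-tuning.
```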