diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 557a6f490e0..b353bb38bbd 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -743,10 +743,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
         )
 
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
-
-    # We want to quantize (in the source transforms) the weights of the model
-    # in the checkpoint dtype.
+    # Quantize weights in checkpoint dtype for accuracy, then cast to
+    # dtype_override afterward. IntxUnpackedToInt8Tensor.to() properly
+    # propagates the dtype change to scale/zero_point/output dtype.
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
@@ -791,9 +790,14 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
             local_global_attention=llm_config.model.local_global_attention,
             use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
             use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
+            quantize_with_hqq=llm_config.quantization.use_hqq,
         )
     )
 
+    # Now cast to the dtype override after quantization, so non-quantized
+    # components use the desired computation dtype.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+
     return edge_manager
 
 
@@ -1736,8 +1740,7 @@ def _get_source_transforms( # noqa
             get_quant_embedding_transform(
                 embedding_quantize,
                 use_shared_embedding,
-                checkpoint_dtype,
-                quantize_with_hqq,
+                quantize_with_hqq=quantize_with_hqq,
             )
         )
 
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 560035f8e5b..b50bff92662 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -755,6 +755,14 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
         )
 
+    def _apply(self, fn, recurse=True):
+        """Override _apply to update self.dtype when the module is cast via .to(dtype)."""
+        super()._apply(fn, recurse)
+        # Probe the new dtype from the scales buffer, which gets cast by super()._apply.
+        if self.scales is not None:
+            self.dtype = self.scales.dtype
+        return self
+
 
 ############################ Source Transform Start #######################
 
@@ -762,7 +770,6 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 def get_quant_embedding_transform(
     embedding_quantize: str,
     use_shared_embedding: bool = False,
-    dtype_override: Optional[DType] = None,
     quantize_with_hqq: bool = True,
 ):
     if embedding_quantize.startswith("torchao:"):
@@ -817,13 +824,11 @@ def _torchao_embedding_quantizer(model):
     else:
         group_size = int(group_size)
     bitwidth = int(bitwidth)
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
     return lambda model: EmbeddingQuantHandler(
         model,
         bitwidth=bitwidth,
         group_size=group_size,
         packed=(bitwidth in [2, 4]),
-        precision=torch_dtype,
         quantize_with_hqq=quantize_with_hqq,
     ).quantized_model()
 
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index f4bdfbf1a0d..63ffe03a9fe 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -429,6 +429,7 @@ class QuantizationConfig:
     calibration_limit: Optional[int] = None
     calibration_seq_length: Optional[int] = None
     calibration_data: str = "Once upon a time"
+    use_hqq: bool = True
 
     def __post_init__(self):
         if self.qmode: