Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,9 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
# Quantize weights in checkpoint dtype for accuracy, then cast to
# dtype_override afterward. IntxUnpackedToInt8Tensor.to() properly
# propagates the dtype change to scale/zero_point/output dtype.
logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
_get_source_transforms(
Expand Down Expand Up @@ -791,9 +790,14 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
local_global_attention=llm_config.model.local_global_attention,
use_torchao_kernels_linear=llm_config.backend.torchao.use_torchao_kernels_linear,
use_torchao_kernels_tied_embedding=llm_config.backend.torchao.use_torchao_kernels_tied_embedding,
quantize_with_hqq=llm_config.quantization.use_hqq,
)
)

# Now cast to the dtype override after quantization, so non-quantized
# components use the desired computation dtype.
edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())

return edge_manager


Expand Down Expand Up @@ -1736,8 +1740,7 @@ def _get_source_transforms( # noqa
get_quant_embedding_transform(
embedding_quantize,
use_shared_embedding,
checkpoint_dtype,
quantize_with_hqq,
quantize_with_hqq=quantize_with_hqq,
)
)

Expand Down
11 changes: 8 additions & 3 deletions examples/models/llama/source_transformation/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,14 +755,21 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
)

def _apply(self, fn, recurse=True):
    """Keep ``self.dtype`` in sync when the module is cast via ``.to(dtype)``.

    ``super()._apply`` casts the registered tensors (including ``scales``);
    afterwards the new dtype is read back from ``scales`` so that subsequent
    dequantization uses the cast precision. Returns ``self`` to preserve the
    fluent ``Module._apply`` contract.
    """
    super()._apply(fn, recurse)
    scales = self.scales
    if scales is not None:
        # scales was cast by super()._apply above; mirror its dtype.
        self.dtype = scales.dtype
    return self


############################ Source Transform Start #######################


def get_quant_embedding_transform(
embedding_quantize: str,
use_shared_embedding: bool = False,
dtype_override: Optional[DType] = None,
quantize_with_hqq: bool = True,
):
if embedding_quantize.startswith("torchao:"):
Expand Down Expand Up @@ -817,13 +824,11 @@ def _torchao_embedding_quantizer(model):
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
return lambda model: EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
precision=torch_dtype,
quantize_with_hqq=quantize_with_hqq,
).quantized_model()

Expand Down
1 change: 1 addition & 0 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ class QuantizationConfig:
calibration_limit: Optional[int] = None
calibration_seq_length: Optional[int] = None
calibration_data: str = "Once upon a time"
use_hqq: bool = True

def __post_init__(self):
if self.qmode:
Expand Down
Loading