[Pytorch][Common] Hybrid quantization #2817

Open
negvet wants to merge 3 commits into NVIDIA:main from negvet:hybrid_quantization
Conversation

Collaborator

@negvet negvet commented Mar 31, 2026

Description

Hybrid quantization is functional.
C++ optimizations will come in the next PRs.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 31, 2026

Greptile Summary

This PR introduces hybrid quantization, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., FP8 rowwise + NVFP4 columnwise). The core additions are HybridQuantizer, HybridQuantizedTensor, and HybridQuantizedTensorStorage, with GEMM dispatch via _unwrap_hybrid_A/_unwrap_hybrid_B helpers that extract the layout-appropriate sub-storage based on the transpose flags in the layout string.

Confidence Score: 4/5

Safe to merge; a P1 from a prior review thread about _hybrid_split_quantize crashing on mixed-quantizer lists appears unresolved

New findings are P2 only (quantize_impl ignores usage flags, wasting computation when set_usage is called). The unresolved P1 from the prior review thread—_hybrid_split_quantize AttributeError on mixed-quantizer lists—keeps the score at 4.

transformer_engine/pytorch/module/grouped_linear.py (_hybrid_split_quantize mixed-quantizer safety), transformer_engine/pytorch/tensor/hybrid_tensor.py (quantize_impl usage-flag propagation)

Important Files Changed

  • transformer_engine/pytorch/tensor/hybrid_tensor.py: Introduces HybridQuantizer and HybridQuantizedTensor; quantize_impl always runs both sub-quantizers regardless of set_usage flags
  • transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py: Correct storage composition mixin; __repr__ renders "NoneType" instead of "None" for absent sub-storages
  • transformer_engine/pytorch/module/grouped_linear.py: Adds _hybrid_split_quantize and hybrid branches in forward/backward; function crashes on mixed-quantizer lists (pre-existing thread)
  • transformer_engine/pytorch/cpp_extensions/gemm.py: Adds _unwrap_hybrid_A/_unwrap_hybrid_B; correctly routes to rowwise or columnwise sub-storage based on GEMM layout flags
  • transformer_engine/pytorch/module/layernorm_linear.py: Disables fused quantized-norm path for HybridQuantizer; minimal hybrid-specific changes
  • transformer_engine/pytorch/module/layernorm_mlp.py: Mirrors layernorm_linear.py hybrid handling for the fc1 input quantizer
  • transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu: New CUDA kernel for blockwise cast+transpose with TMA hardware acceleration on SM90+
  • transformer_engine/pytorch/module/base.py: Adds HybridQuantizer import; no substantive logic change
  • transformer_engine/pytorch/tensor/__init__.py: Exports HybridQuantizer, HybridQuantizedTensor, HybridQuantizedTensorStorage
  • transformer_engine/pytorch/__init__.py: Re-exports hybrid tensor types at the top-level package
  • tests/pytorch/test_hybrid_quantization.py: Comprehensive 1600-line test suite covering construction, quantization, storage operations, and end-to-end module paths

Sequence Diagram

sequenceDiagram
    participant C as Caller
    participant HQ as HybridQuantizer
    participant RQ as rowwise_quantizer
    participant CQ as columnwise_quantizer
    participant HT as HybridQuantizedTensor
    participant G as general_gemm
    participant U as _unwrap_hybrid_A/B

    C->>HQ: quantize(tensor)
    HQ->>RQ: quantize(tensor)
    RQ-->>HQ: rowwise_result (e.g. Float8Tensor)
    HQ->>CQ: quantize(tensor)
    CQ-->>HQ: columnwise_result (e.g. NVFP4Tensor)
    HQ-->>HT: wrap(rowwise_result, columnwise_result)
    HT-->>C: HybridQuantizedTensor

    C->>G: general_gemm(A=HybridQuantizedTensor, layout="TN")
    G->>U: _unwrap_hybrid_A(A, "TN")
    Note over U: layout[0]=="T" → rowwise_sub_storage
    U-->>G: Float8Tensor
    G->>U: _unwrap_hybrid_B(B, "TN")
    Note over U: layout[1]=="N" → rowwise_sub_storage
    U-->>G: rowwise storage
    G-->>C: output tensor
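The layout-based dispatch in the diagram's notes can be sketched as follows. This is an illustration only: the attribute names `rowwise_sub_storage`/`columnwise_sub_storage` and the lowercase helper names are assumptions, standing in for the PR's actual `_unwrap_hybrid_A`/`_unwrap_hybrid_B` helpers.

```python
from types import SimpleNamespace

def unwrap_hybrid_a(a, layout):
    # Operand A: "T" in layout[0] maps to the rowwise sub-storage,
    # "N" to the columnwise one (per the diagram's note for A).
    return a.rowwise_sub_storage if layout[0] == "T" else a.columnwise_sub_storage

def unwrap_hybrid_b(b, layout):
    # Operand B: "N" in layout[1] maps to the rowwise sub-storage,
    # "T" to the columnwise one (per the diagram's note for B).
    return b.rowwise_sub_storage if layout[1] == "N" else b.columnwise_sub_storage
```

With layout "TN", both operands resolve to their rowwise sub-storage, matching the sequence above.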


Reviews (2): Last reviewed commit: "Merge branch 'main' into hybrid_quantiza..."

Comment on lines +64 to +86
def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for HybridQuantizer lists.

    Runs tex.split_quantize twice (once per direction with the native
    sub-quantizers), then zips the results into HybridQuantizedTensorStorage.
    Non-hybrid quantizers in the list fall back to per-split Python quantize.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    row_quantizers = [q.rowwise_quantizer for q in quantizers]
    col_quantizers = [q.columnwise_quantizer for q in quantizers]

    row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
    col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

    return [
        HybridStorage(
            rowwise_storage=row,
            columnwise_storage=col,
            rowwise_quantizer=rq,
            columnwise_quantizer=cq,
            quantizer=q,
            fake_dtype=tensor.dtype,

P1 _hybrid_split_quantize crashes on mixed-quantizer lists

_has_hybrid_quantizer returns True if any quantizer in the list is a HybridQuantizer, but _hybrid_split_quantize unconditionally accesses q.rowwise_quantizer and q.columnwise_quantizer for every element. If the list contains even one non-hybrid quantizer, this raises AttributeError at runtime.

The docstring claims "Non-hybrid quantizers in the list fall back to per-split Python quantize", but no such fallback exists in the implementation:

row_quantizers = [q.rowwise_quantizer for q in quantizers]  # crashes if q is not HybridQuantizer
col_quantizers = [q.columnwise_quantizer for q in quantizers]

Either the condition at the call site should assert all-or-nothing hybrid (all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None)), or the function needs to implement the per-element fallback its docstring promises. The same issue applies to all three call sites in both the forward and backward paths.
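The first option, an all-or-nothing gate at the call site, can be sketched like this. The stub classes and the helper name `should_use_hybrid_split` are illustrative, not the PR's code; only the gating predicate matters.

```python
class HybridQuantizer:
    """Stand-in for transformer_engine's HybridQuantizer (illustration only)."""

class Float8Quantizer:
    """Stand-in for any non-hybrid quantizer (illustration only)."""

def should_use_hybrid_split(quantizers):
    # All-or-nothing gate: take the grouped hybrid path only when every
    # non-None quantizer in the list is hybrid. Mixed lists fall back to
    # the generic per-split path, avoiding the AttributeError above.
    present = [q for q in quantizers if q is not None]
    return bool(present) and all(isinstance(q, HybridQuantizer) for q in present)
```

The alternative, implementing the per-element fallback the docstring promises, would require splitting the list and merging results, which is more invasive than tightening the guard.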

Comment on lines +146 to +155
            "fake_dtype": self._dtype,
        }

    def __repr__(self):
        return (
            "HybridQuantizedTensorStorage("
            f"rowwise={type(self._rowwise_storage).__name__}, "
            f"columnwise={type(self._columnwise_storage).__name__}, "
            f"dtype={self._dtype})"
        )

P2 __repr__ shows NoneType instead of None for missing sub-storages

When _rowwise_storage or _columnwise_storage is None, type(None).__name__ produces the string "NoneType" rather than "None". HybridQuantizedTensor.__repr__ already handles this correctly with an explicit is not None guard.
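A minimal sketch of the fix, using the explicit `is not None` guard the comment points to. The class below is a stand-in with only the fields `__repr__` touches, not the real storage mixin:

```python
class HybridQuantizedTensorStorage:
    """Illustrative stand-in; attribute names copied from the snippet above."""

    def __init__(self, rowwise_storage=None, columnwise_storage=None, dtype=None):
        self._rowwise_storage = rowwise_storage
        self._columnwise_storage = columnwise_storage
        self._dtype = dtype

    def __repr__(self):
        # Render "None" (not "NoneType") when a sub-storage is absent.
        row = type(self._rowwise_storage).__name__ if self._rowwise_storage is not None else "None"
        col = type(self._columnwise_storage).__name__ if self._columnwise_storage is not None else "None"
        return f"HybridQuantizedTensorStorage(rowwise={row}, columnwise={col}, dtype={self._dtype})"
```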

Comment on lines +75 to +97
    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> HybridQuantizedTensor:
        self.rowwise_quantizer.internal = True
        rowwise_empty = self.rowwise_quantizer.make_empty(
            shape,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory,
        )
        self.rowwise_quantizer.internal = False

        self.columnwise_quantizer.internal = True
        columnwise_empty = self.columnwise_quantizer.make_empty(
            shape,
            dtype=dtype,
            device=device,

P2 make_empty leaves sub-quantizer internal flag set on exception

If make_empty raises, the internal = False reset is skipped and the sub-quantizer is permanently left with internal=True. Consider using a try/finally block for both sub-quantizer flag resets.
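The try/finally shape the comment suggests can be sketched as below. The quantizer stub and the helper name `make_empty_guarded` are illustrative; only the flag-handling pattern matters.

```python
class FlakyQuantizer:
    """Illustrative stub whose make_empty may raise."""

    def __init__(self, fail=False):
        self.internal = False
        self.fail = fail

    def make_empty(self, shape):
        if self.fail:
            raise RuntimeError("allocation failed")
        return ("empty", shape)

def make_empty_guarded(quantizer, shape):
    # Restore the internal flag even if make_empty raises, so the
    # sub-quantizer is never permanently left with internal=True.
    quantizer.internal = True
    try:
        return quantizer.make_empty(shape)
    finally:
        quantizer.internal = False
```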

Collaborator

@timmoon10 timmoon10 left a comment


Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
        rowwise_result = self.rowwise_quantizer.quantize(tensor)
        columnwise_result = self.columnwise_quantizer.quantize(tensor)

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change:

-        rowwise_result = self.rowwise_quantizer.quantize(tensor)
-        columnwise_result = self.columnwise_quantizer.quantize(tensor)
+        rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+        columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
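The gating behavior the reviewer asks for can be modeled in a few lines. This is a toy class, not the real HybridQuantizer: the sub-quantizers are plain callables, and only the attribute names mirror the PR.

```python
class GatedHybridQuantizer:
    """Toy model of usage-gated hybrid quantization (illustration only)."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer
        self.rowwise_usage = True
        self.columnwise_usage = True

    def quantize(self, tensor):
        # Skip whichever direction set_usage has disabled, so disabling
        # a usage actually saves the corresponding quantization work.
        rowwise = self.rowwise_quantizer(tensor) if self.rowwise_usage else None
        columnwise = self.columnwise_quantizer(tensor) if self.columnwise_usage else None
        return rowwise, columnwise
```

When columnwise usage is switched off, only the rowwise sub-quantizer runs and the columnwise slot is left as None.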

        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> HybridQuantizedTensor:
        self.rowwise_quantizer.internal = True

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Comment on lines +114 to +118
    def set_usage(
        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
    ) -> None:
        super().set_usage(rowwise=rowwise, columnwise=columnwise)


This is redundant:

Suggested change:

-    def set_usage(
-        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
-    ) -> None:
-        super().set_usage(rowwise=rowwise, columnwise=columnwise)

Comment on lines +1339 to +1355
def factory(role):
    if role == "linear_weight":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_mxfp8_quantizer(),
        )
    if role == "linear_input":
        return HybridQuantizer(
            rowwise_quantizer=_make_fp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    if role in ("linear_grad_output", "linear_grad_input"):
        return HybridQuantizer(
            rowwise_quantizer=_make_mxfp8_quantizer(),
            columnwise_quantizer=_make_nvfp4_quantizer(),
        )
    return None

This is horrifying. Good test.

