Conversation
Signed-off-by: Evgeny <etsykunov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR introduces hybrid quantization, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., FP8 rowwise + NVFP4 columnwise).

Confidence Score: 4/5. Safe to merge, but a P1 from a prior review thread about `_hybrid_split_quantize` crashing on mixed-quantizer lists appears unresolved, which keeps the score at 4. New findings are P2 only (`quantize_impl` ignores usage flags, wasting computation when `set_usage` is called).

Important Files Changed: transformer_engine/pytorch/module/grouped_linear.py (`_hybrid_split_quantize` mixed-quantizer safety), transformer_engine/pytorch/tensor/hybrid_tensor.py (`quantize_impl` usage-flag propagation).
Sequence Diagram

```mermaid
sequenceDiagram
    participant C as Caller
    participant HQ as HybridQuantizer
    participant RQ as rowwise_quantizer
    participant CQ as columnwise_quantizer
    participant HT as HybridQuantizedTensor
    participant G as general_gemm
    participant U as _unwrap_hybrid_A/B
    C->>HQ: quantize(tensor)
    HQ->>RQ: quantize(tensor)
    RQ-->>HQ: rowwise_result (e.g. Float8Tensor)
    HQ->>CQ: quantize(tensor)
    CQ-->>HQ: columnwise_result (e.g. NVFP4Tensor)
    HQ-->>HT: wrap(rowwise_result, columnwise_result)
    HT-->>C: HybridQuantizedTensor
    C->>G: general_gemm(A=HybridQuantizedTensor, layout="TN")
    G->>U: _unwrap_hybrid_A(A, "TN")
    Note over U: layout[0]=="T" → rowwise_sub_storage
    U-->>G: Float8Tensor
    G->>U: _unwrap_hybrid_B(B, "TN")
    Note over U: layout[1]=="N" → rowwise_sub_storage
    U-->>G: rowwise storage
    G-->>C: output tensor
```
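The unwrap step in the diagram picks one sub-storage per operand based on the GEMM layout string. A minimal sketch of that selection logic, with a stand-in `HybridQuantizedTensor` class and lowercase helper names (the real helpers are `_unwrap_hybrid_A`/`_unwrap_hybrid_B`; the exact attribute names here are assumptions):

```python
class HybridQuantizedTensor:
    """Stand-in: holds one quantized storage per GEMM direction."""

    def __init__(self, rowwise_sub_storage, columnwise_sub_storage):
        self.rowwise_sub_storage = rowwise_sub_storage
        self.columnwise_sub_storage = columnwise_sub_storage


def unwrap_hybrid_a(a, layout):
    # Operand A: layout[0] == "T" maps to the rowwise storage
    # (as in the sequence diagram); otherwise use columnwise.
    if not isinstance(a, HybridQuantizedTensor):
        return a
    return a.rowwise_sub_storage if layout[0] == "T" else a.columnwise_sub_storage


def unwrap_hybrid_b(b, layout):
    # Operand B: layout[1] == "N" maps to the rowwise storage.
    if not isinstance(b, HybridQuantizedTensor):
        return b
    return b.rowwise_sub_storage if layout[1] == "N" else b.columnwise_sub_storage
```

For a "TN" GEMM both operands therefore resolve to their rowwise storages, which is the path shown above.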
Reviews (2). Last reviewed commit: "Merge branch 'main' into hybrid_quantiza..."
```python
def _hybrid_split_quantize(tensor, m_splits, quantizers):
    """Grouped split+quantize for HybridQuantizer lists.

    Runs tex.split_quantize twice (once per direction with the native
    sub-quantizers), then zips the results into HybridQuantizedTensorStorage.
    Non-hybrid quantizers in the list fall back to per-split Python quantize.
    """
    from ..tensor.storage.hybrid_tensor_storage import HybridQuantizedTensorStorage as HybridStorage

    row_quantizers = [q.rowwise_quantizer for q in quantizers]
    col_quantizers = [q.columnwise_quantizer for q in quantizers]

    row_results = tex.split_quantize(tensor, m_splits, row_quantizers)
    col_results = tex.split_quantize(tensor, m_splits, col_quantizers)

    return [
        HybridStorage(
            rowwise_storage=row,
            columnwise_storage=col,
            rowwise_quantizer=rq,
            columnwise_quantizer=cq,
            quantizer=q,
            fake_dtype=tensor.dtype,
```
_hybrid_split_quantize crashes on mixed-quantizer lists
_has_hybrid_quantizer returns True if any quantizer in the list is a HybridQuantizer, but _hybrid_split_quantize unconditionally accesses q.rowwise_quantizer and q.columnwise_quantizer for every element. If the list contains even one non-hybrid quantizer, this raises AttributeError at runtime.
The docstring claims "Non-hybrid quantizers in the list fall back to per-split Python quantize", but no such fallback exists in the implementation:

```python
row_quantizers = [q.rowwise_quantizer for q in quantizers]  # crashes if q is not HybridQuantizer
col_quantizers = [q.columnwise_quantizer for q in quantizers]
```

Either the condition at the call site should assert all-or-nothing hybrid (`all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None)`), or the function needs to implement the per-element fallback its docstring promises. The same issue applies to all three call sites in both the forward and backward paths.
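The all-or-nothing guard the reviewer proposes can be sketched as follows, with a stand-in `HybridQuantizer` class (the real one lives in transformer_engine; `all_hybrid` is a hypothetical helper name):

```python
class HybridQuantizer:
    """Stand-in for the real HybridQuantizer (attribute names assumed)."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer


def all_hybrid(quantizers):
    # All-or-nothing guard: take the fused split-quantize path only when
    # every non-None quantizer is hybrid, so the list comprehensions in
    # _hybrid_split_quantize can never hit a missing attribute.
    return all(isinstance(q, HybridQuantizer) for q in quantizers if q is not None)
```

Gating the call site on this predicate (instead of an any-based check) makes the mixed-list case fall through to the existing per-split Python path rather than raising `AttributeError`.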
```python
            "fake_dtype": self._dtype,
        }

    def __repr__(self):
        return (
            "HybridQuantizedTensorStorage("
            f"rowwise={type(self._rowwise_storage).__name__}, "
            f"columnwise={type(self._columnwise_storage).__name__}, "
            f"dtype={self._dtype})"
        )
```
```python
    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> HybridQuantizedTensor:
        self.rowwise_quantizer.internal = True
        rowwise_empty = self.rowwise_quantizer.make_empty(
            shape,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory,
        )
        self.rowwise_quantizer.internal = False

        self.columnwise_quantizer.internal = True
        columnwise_empty = self.columnwise_quantizer.make_empty(
            shape,
            dtype=dtype,
            device=device,
```
timmoon10 left a comment:
Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.
```python
        rowwise_result = self.rowwise_quantizer.quantize(tensor)
        columnwise_result = self.columnwise_quantizer.quantize(tensor)
```
Do we handle the case where not all usages are needed? I'd expect something like:
```diff
- rowwise_result = self.rowwise_quantizer.quantize(tensor)
- columnwise_result = self.columnwise_quantizer.quantize(tensor)
+ rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
+ columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None
```
```python
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> HybridQuantizedTensor:
        self.rowwise_quantizer.internal = True
```
Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.
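What the reviewer suggests could look roughly like this: set `internal` once at construction instead of toggling it around each `make_empty` call (the class body is a stand-in; only the attribute names follow the snippet above):

```python
class HybridQuantizer:
    """Sketch: mark both sub-quantizers as internal at construction."""

    def __init__(self, rowwise_quantizer, columnwise_quantizer):
        self.rowwise_quantizer = rowwise_quantizer
        self.columnwise_quantizer = columnwise_quantizer
        # Per-usage data never needs full PyTorch tensor functionality,
        # so internal can be set once here rather than flipped per call.
        self.rowwise_quantizer.internal = True
        self.columnwise_quantizer.internal = True
```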
```python
    def set_usage(
        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
    ) -> None:
        super().set_usage(rowwise=rowwise, columnwise=columnwise)
```
This is redundant:
```diff
-    def set_usage(
-        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
-    ) -> None:
-        super().set_usage(rowwise=rowwise, columnwise=columnwise)
```
```python
    def factory(role):
        if role == "linear_weight":
            return HybridQuantizer(
                rowwise_quantizer=_make_fp8_quantizer(),
                columnwise_quantizer=_make_mxfp8_quantizer(),
            )
        if role == "linear_input":
            return HybridQuantizer(
                rowwise_quantizer=_make_fp8_quantizer(),
                columnwise_quantizer=_make_nvfp4_quantizer(),
            )
        if role in ("linear_grad_output", "linear_grad_input"):
            return HybridQuantizer(
                rowwise_quantizer=_make_mxfp8_quantizer(),
                columnwise_quantizer=_make_nvfp4_quantizer(),
            )
        return None
```
This is horrifying. Good test.
Description
Hybrid quantization is functional.
C++ optimizations will come in the next PRs.