Conversation
Hi @sarthak-amd Kernel-wise it should be almost identical except for how the block is cast. So we can just put this part in a separate function: TransformerEngine/transformer_engine/pytorch/triton_kernels/cast_transpose.py Lines 562 to 571 in ef83316, and switch between the calls based on whether it's an MXFP8 kernel or an MXFP4 kernel. The following operation can be replaced with the `exp2f_rcp_triton` function: TransformerEngine/transformer_engine/pytorch/triton_kernels/cast_transpose.py Lines 549 to 551 in ef83316. Another difference seems to be that the MXFP8 kernels don't have the shuffle op fused in. Do we want that for MXFP8 as well, or is that specific to FP4 data? Also @wangye805 noticed. Let me know what you think @sarthak-amd
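To illustrate the `exp2f_rcp_triton` suggestion above: an E8M0 scale is a pure power of two, 2^(e - 127) for biased exponent e, so its reciprocal can be computed with a single exp2 instead of a division. The sketch below is plain Python purely to show the arithmetic; the real kernel operates on Triton tensors, and the exact signature of `exp2f_rcp_triton` is not shown in this thread.

```python
def e8m0_scale_rcp(biased_exp: int) -> float:
    """Reciprocal of an E8M0 scale 2^(biased_exp - 127).

    Computed as a single exp2 of the negated unbiased exponent,
    avoiding a floating-point division -- the idea behind the
    exp2f_rcp replacement suggested in the review comment.
    """
    return 2.0 ** (127 - biased_exp)
```

With biased exponent 127 (scale 1.0) the reciprocal is 1.0; with 128 (scale 2.0) it is 0.5.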
- Updated `test_cast_mxfp4.py` to simplify quantization output handling by removing unnecessary output tensor creation.
- Introduced `MXFP4BlockScaling` recipe class.
- Enhanced `MXFP4Quantizer` to use the new scaling methods and updated tensor creation logic.
- Added new quantization kernel `_mxfp4_quantize_32x32_block` to remove redundant work.
- Updated Triton kernel wrapper.
- Updated `mxfp4_quantize_cpu` to include a `SHUFFLE` parameter for conditional scale shuffling.
- Modified tests in `test_cast_mxfp4.py` to accommodate the new shuffling logic and added parameterization for shuffle options.
- Bug fix: removed redundant `QuantizedTensorBase`.
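A `SHUFFLE`-style switch in a CPU reference can be sketched as follows. The exact MFMA shuffle layout is hardware-specific and not shown in this thread, so this toy uses a placeholder permutation (an adjacent-pair swap) purely to show how the flag gates the rearrangement and how a test can undo it before comparison; `shuffle_scales` and `un_shuffle_scales` here are illustrative, not the repo's actual helpers.

```python
def shuffle_scales(scales, shuffle=False):
    """Optionally permute scale entries, gated by a SHUFFLE-like flag.

    The pair swap below is a stand-in for the real MFMA tile layout.
    """
    if not shuffle:
        return list(scales)
    out = list(scales)
    for i in range(0, len(out) - 1, 2):
        out[i], out[i + 1] = out[i + 1], out[i]
    return out


def un_shuffle_scales(scales):
    """Invert the placeholder shuffle (a pair swap is an involution)."""
    return shuffle_scales(scales, shuffle=True)
```

A test parameterized over the flag can then compare shuffled outputs against the un-shuffled reference by applying `un_shuffle_scales` first.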
New changes
- Removed redundant `_empty_tensor` function from `utils.py`.
- Ensured proper newline at the end of the file in `quantized_tensor.py`.
- Changed `fp8_format` to `fp4_format` for consistency with the new scaling method.
```python
compare_fp4_data_nibblewise(
    quantized_out._rowwise_data.view(torch.uint8),
    ref_data,
    msg=f"Rowwise FP4 ({shape}, {in_dtype})",
    max_mismatch_rate=0.05,
)
y1_scales_triton = quantized_out._rowwise_scale_inv.view(torch.uint8)
y1_scales_torch = ref_scale
if shuffle_B_matrix_for_aiter:
    y1_scales_triton = un_shuffle_scales(
        y1_scales_triton.view(y1_scales_triton.shape[0] // 32, -1)
    )
    y1_scales_torch = un_shuffle_scales(
        y1_scales_torch.view(y1_scales_torch.shape[0] // 32, -1)
    )

compare_e8m0_scales(
```
Okay, currently the target vs. ref comparison is not coupled: we validate the quantized data content and the scales independently and allow for a certain mismatch rate.
A better way is to do the validation jointly, adjusting the quantized value when the scale is mismatched:
TransformerEngine/tests/cpp/operator/test_cast_mxfp8.cu
Lines 314 to 353 in 9b31283
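The joint validation idea referenced above can be sketched like this: instead of checking data bytes and E8M0 scales separately, compare the dequantized products, so that an off-by-one scale exponent paired with a compensating mantissa still counts as a match. This is a hedged Python illustration; the actual check lives in the C++ test linked above, and `values_match` is a hypothetical helper.

```python
def values_match(q_test, e_test, q_ref, e_ref, tol=0.0):
    """Compare dequantized values q * 2^(e - 127) for E8M0 scales.

    Coupling data and scale this way accepts equivalent encodings
    (e.g. mantissa 2.0 with exponent 127 vs. mantissa 1.0 with 128).
    """
    v_test = q_test * 2.0 ** (e_test - 127)
    v_ref = q_ref * 2.0 ** (e_ref - 127)
    return abs(v_test - v_ref) <= tol
```

For example, (q=2.0, e=127) and (q=1.0, e=128) dequantize to the same value and match, while a bare element-wise comparison would flag both the data and the scale as mismatched.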
Description
Implements the MXFP4 `rowwise` and `columnwise` FP32/BF16 -> MXFP4 fused quantization + cast kernel. Verify tolerances and functional unit tests.
The triton `te_cast_transpose_mxfp4_triton` currently outputs FP4 data in a linear layout [M, N/2] with contiguous byte packing. AITER's `gemm_a4w4` requires the B matrix in MFMA shuffle layout for tensor cores. This layout shuffle can be fused into the triton kernel in the future.
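The [M, N/2] layout above packs two FP4 values into each output byte. A minimal sketch of that contiguous byte packing, assuming the even column goes in the low nibble (the nibble order is an assumption; the kernel may use the opposite convention):

```python
def pack_fp4_row(nibbles):
    """Pack a row of 4-bit codes into bytes, two values per byte.

    Assumed convention: even-index value in the low nibble,
    odd-index value in the high nibble.
    """
    assert len(nibbles) % 2 == 0, "row length must be even (N/2 bytes)"
    return bytes(
        (nibbles[i] & 0xF) | ((nibbles[i + 1] & 0xF) << 4)
        for i in range(0, len(nibbles), 2)
    )
```

This is why the tests view the quantized output as `torch.uint8` and compare nibble-wise: each byte holds two adjacent FP4 elements.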