MXFP4 Cast Transpose Triton [WIP]#422

Open
sarthak-amd wants to merge 18 commits into dev from feature/cast-transpose-mxfp4

Conversation

@sarthak-amd
Collaborator

@sarthak-amd sarthak-amd commented Jan 20, 2026

Description

Implements the rowwise and columnwise FP32/BF16 -> MXFP4 fused quantization + cast kernel.

  • Verify tolerances and functional unit tests

  • The Triton te_cast_transpose_mxfp4_triton kernel currently outputs FP4 data in a linear [M, N/2] layout with contiguous byte packing. AITER's gemm_a4w4 requires the B matrix in the MFMA shuffle layout for tensor cores; this layout shuffle can be fused into the Triton kernel in the future.
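The linear byte packing can be sketched as follows (a hypothetical illustration, assuming two adjacent E2M1 codes per byte with the even column in the low nibble; the kernel's actual nibble order may differ):

```python
import numpy as np

def pack_fp4_rowwise(codes: np.ndarray) -> np.ndarray:
    """Pack an [M, N] array of 4-bit codes (values 0..15) into [M, N//2] bytes."""
    assert codes.shape[1] % 2 == 0
    lo = codes[:, 0::2] & 0xF  # even columns -> low nibble (an assumption)
    hi = codes[:, 1::2] & 0xF  # odd columns  -> high nibble
    return ((hi << 4) | lo).astype(np.uint8)
```

Two FP4 elements per byte is what gives the [M, N/2] shape; the MFMA shuffle for gemm_a4w4 would reorder these bytes, not change the packing itself.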

Collaborator

@wangye805 wangye805 left a comment


@wangye805 wangye805 requested a review from sudhu2k February 12, 2026 22:40
@sudhu2k
Contributor

sudhu2k commented Feb 16, 2026

Hi @sarthak-amd
Can't we merge the MXFP8 Triton kernel with the MXFP4 Triton kernel?

Kernel-wise they should be nearly identical except for how the block is cast.
In MXFP8 we use a separate function to cast the values:

def float_to_e8m0_triton(val: tl.float32) -> tl.uint8:

So we can just put this part in a separate function:
# Nearest-neighbor quantization to E2M1 values
# E2M1 representable values: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}
idx_row = tl.zeros([MXFP4_BLOCK_SIZE, MXFP4_BLOCK_SIZE], dtype=tl.uint8)
idx_row = tl.where(abs_qx_row >= 0.25, 1, idx_row) # → 0.5
idx_row = tl.where(abs_qx_row >= 0.75, 2, idx_row) # → 1.0
idx_row = tl.where(abs_qx_row >= 1.25, 3, idx_row) # → 1.5
idx_row = tl.where(abs_qx_row >= 1.75, 4, idx_row) # → 2.0
idx_row = tl.where(abs_qx_row >= 2.5, 5, idx_row) # → 3.0
idx_row = tl.where(abs_qx_row >= 3.5, 6, idx_row) # → 4.0
idx_row = tl.where(abs_qx_row >= 5.0, 7, idx_row) # → 6.0

and dispatch between the two based on whether it's an MXFP8 or an MXFP4 kernel.
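The thresholds in the snippet above are the midpoints between adjacent E2M1 magnitudes, i.e. nearest-neighbor rounding with ties going up. A NumPy sketch of the same index selection (sign bit handled separately, as in the kernel):

```python
import numpy as np

# E2M1 representable magnitudes, as in the kernel comment
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def e2m1_index(abs_x: np.ndarray) -> np.ndarray:
    """Index of the nearest E2M1 magnitude; ties round up, matching the
    chain of tl.where(abs_qx_row >= threshold, ...) comparisons."""
    # Midpoints between neighbors: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
    thresholds = (E2M1_VALUES[:-1] + E2M1_VALUES[1:]) / 2
    return np.searchsorted(thresholds, abs_x, side="right").astype(np.uint8)
```

Pulling this into a shared helper, as suggested, would let one kernel body serve both formats with only the casting function swapped.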

The following operation can be replaced with the exp2f_rcp_triton function:

scale_unbiased_row = tl.log2(tl.maximum(amax_rounded, 1e-45)).floor() - 2
scale_unbiased_row = tl.clamp(scale_unbiased_row, min=-127.0, max=127.0)
quant_scale_row = tl.exp2(-scale_unbiased_row)
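For reference, the scale computation above can be re-stated in NumPy (a plain sketch of these three lines, not the exp2f_rcp_triton implementation): the -2 offset aligns the block amax with E2M1's maximum magnitude of 6.0 ≈ 2^2 · 1.5.

```python
import numpy as np

def e8m0_quant_scale(amax: np.ndarray):
    """Power-of-two block scale and its reciprocal, mirroring the kernel:
    floor(log2(amax)) - 2, clamped to the E8M0 unbiased-exponent range."""
    amax = np.maximum(amax, 1e-45)        # avoid log2(0), as in the kernel
    s = np.floor(np.log2(amax)) - 2       # unbiased power-of-two exponent
    s = np.clip(s, -127.0, 127.0)         # E8M0 representable range
    return s, np.exp2(-s)                 # (exponent, multiply-by quant scale)
```

E.g. a block amax of 24.0 yields exponent 2 and quant scale 0.25, mapping 24.0 onto E2M1's top value 6.0.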

Another difference is that the MXFP8 kernel doesn't have the shuffle op fused into it. Do we want that for MXFP8 as well, or is it specific to FP4 data?

Also, @wangye805 noticed: for the column-wise quantization of an MxN input tensor, MXFP8 still produces columnwise output of shape MxN, but the MXFP4 implementation produces NxM. Technically we could add something like a STORE_TRANSPOSE flag to the kernel and invert the strides in the MXFP8 kernel itself.
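The STORE_TRANSPOSE idea can be sketched in NumPy (STORE_TRANSPOSE and store_tile are hypothetical names from this suggestion, not existing kernel code): inverting the indexing writes each tile transposed, so the columnwise output lands directly in [N, M] with no separate transpose pass.

```python
import numpy as np

def store_tile(out: np.ndarray, tile: np.ndarray, r: int, c: int,
               store_transpose: bool) -> None:
    """Write a [BR, BC] tile at element offset (r, c). With store_transpose
    set, the output is laid out [N, M] and the tile is written transposed,
    which is what swapping the output strides achieves in the kernel."""
    br, bc = tile.shape
    if store_transpose:
        out[c:c + bc, r:r + br] = tile.T  # inverted indexing: [N, M] placement
    else:
        out[r:r + br, c:c + bc] = tile    # plain [M, N] placement
```

In the Triton kernel this would be a constexpr flag selecting between the two stride orders rather than an actual `.T`.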

Let me know what you think @sarthak-amd

sudhu2k added 4 commits March 4, 2026 16:42
- Updated `test_cast_mxfp4.py` to simplify quantization output handling by removing unnecessary output tensor creation.
- Introduced `MXFP4BlockScaling` recipe class.
- Enhanced `MXFP4Quantizer` to utilize new scaling methods and updated tensor creation logic.
- Added new quantization kernel `_mxfp4_quantize_32x32_block` to remove redundant work.
- Updated Triton kernel wrapper.
- Updated `mxfp4_quantize_cpu` to include a `SHUFFLE` parameter for conditional scale shuffling.
- Modified tests in `test_cast_mxfp4.py` to accommodate the new shuffling logic and added parameterization for shuffle options.
- Bug fix, removed redundant QuantizedTensorBase
@sudhu2k
Contributor

sudhu2k commented Mar 4, 2026

New changes

  1. Refactor of kernel wrappers and mxfp4 tensors to be consistent with other recipes.
  2. Decomposition of repeated work into a separate function for the mxfp4 quantize kernel.
  3. Added shuffle tests.
  4. Bug fixes.

@sudhu2k sudhu2k marked this pull request as ready for review March 4, 2026 22:18
@sudhu2k sudhu2k requested a review from wangye805 March 4, 2026 22:18
sudhu2k added 2 commits March 4, 2026 22:19
- Removed redundant _empty_tensor function from utils.py.
- Ensured proper newline at the end of the file in quantized_tensor.py.
- Changed `fp8_format` to `fp4_format` for consistency with the new scaling method.
@sudhu2k sudhu2k self-assigned this Mar 5, 2026
@sudhu2k sudhu2k removed their request for review March 5, 2026 16:59
Comment on lines +305 to +321
compare_fp4_data_nibblewise(
    quantized_out._rowwise_data.view(torch.uint8),
    ref_data,
    msg=f"Rowwise FP4 ({shape}, {in_dtype})",
    max_mismatch_rate=0.05,
)
y1_scales_triton = quantized_out._rowwise_scale_inv.view(torch.uint8)
y1_scales_torch = ref_scale
if shuffle_B_matrix_for_aiter:
    y1_scales_triton = un_shuffle_scales(
        y1_scales_triton.view(y1_scales_triton.shape[0] // 32, -1)
    )
    y1_scales_torch = un_shuffle_scales(
        y1_scales_torch.view(y1_scales_torch.shape[0] // 32, -1)
    )


compare_e8m0_scales(
Collaborator

Okay, currently the target vs. ref comparison is not coupled: we validate the quantized data and the scales independently and allow a certain mismatch rate.

A better way is to do the validation jointly by adjusting the quantized value wherever the scale is mismatched:

#ifdef __HIP_PLATFORM_AMD__
  const double abs_tolerable_mismatches_limit = 1.0;
  const double rel_tolerable_mismatches_limit = 1.0e-4;
#else
  const double abs_tolerable_mismatches_limit = 0.0;
  const double rel_tolerable_mismatches_limit = 0.0;
#endif
  std::vector<size_t> mismatches_scales_indices;
  size_t mismatches_scales = 0;
  compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
                               unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
                               mismatches_scales_indices, mismatches_scales,
                               scale_diff_abs_tolerance,
                               abs_tolerable_mismatches_limit,
                               rel_tolerable_mismatches_limit);
#ifdef __HIP_PLATFORM_AMD__
  if (::testing::Test::HasFatalFailure()) return;
  adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr,
                                  ref_output_scales.get(), scales_stride, rows, cols, rowwise,
                                  ref_output_c.get(), otype);
  mismatches_scales = 0;
#endif
  const size_t mismatches_elts = 32 * mismatches_scales;
  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true,
                 mismatches_elts);
  if (processing_method == ProcessingMethod::CAST_DBIAS
      || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
    auto [atol_dbias, rtol_dbias] = getTolerances(itype);
    if (itype == DType::kFloat32) {
      atol_dbias = 1e-4;
      rtol_dbias *= sqrt(static_cast<double>(rows));
    } else {
      rtol_dbias *= 4;
    }
    compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias,
                   rtol_dbias);
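The adjustment step can be sketched in Python (a hypothetical mirror of what adjust_ref_for_e8m0_scale_error does, under the assumption that a block's quantized values differ by exactly the power-of-two ratio of the two E8M0 exponents):

```python
import numpy as np

def adjust_ref_for_scale_mismatch(ref_vals: np.ndarray, ref_scales: np.ndarray,
                                  gpu_scales: np.ndarray, block: int = 32) -> np.ndarray:
    """Rescale each reference block whose E8M0 exponent disagrees with the
    GPU's, so the subsequent element-wise comparison is done jointly with the
    scale rather than tolerating independent mismatch rates."""
    ref_vals = ref_vals.astype(np.float64).reshape(-1, block)
    for i, (rs, gs) in enumerate(zip(ref_scales.astype(int), gpu_scales.astype(int))):
        if rs != gs:
            # ref stored x * 2^-rs; GPU stored x * 2^-gs, so multiply by 2^(rs-gs)
            ref_vals[i] *= 2.0 ** (rs - gs)
    return ref_vals.reshape(-1)
```

With the reference rescaled this way, the data comparison only flags genuine value errors instead of the benign off-by-one-scale cases.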
