Add FP8 Support For CK Tile Group GEMM #475

Draft
aris134 wants to merge 62 commits into dev from amartin/ck-grouped-gemm-fp8

Conversation


@aris134 aris134 commented Mar 6, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

TODO:

  • Add support for other architectures (e.g., MI350X)
  • Add support for other quantization modes
  • Performance analysis and tuning

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enables mixed precision (fp8/bf8 FNUZ variants) support for CK tile grouped GEMM with tensor quantization

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

matthiasdiener and others added 28 commits February 5, 2026 17:22
Align GemmRowColTensorQuantPipelineProblem with ck_tile V3 requirements by using AccType for intermediate C results. Specific to TensorQuant (per-tensor scaling); limited to e4m3/e5m2 FNUZ formats. Updates test_numerics.py to exercise FP8 inputs in the grouped linear accuracy suite.
Enable mixed FP8/BF8 grouped GEMM for the CK backend used by GroupedLinear backward.

Certain mixed-type combinations normalize to (AType=bf8_t, BType=fp8_t), but CK currently lacks a corresponding warp GEMM specialization for WarpGemmMfma_f32_32x32x32_bf8_fp8. This prevents the default FP8 tile configuration (K_Warp_Tile=32) from compiling or dispatching correctly.

To address this, a fallback tile policy is introduced that routes the (bf8_t, fp8_t) case to a supported kernel configuration using K_Warp_Tile=16. This preserves correct GEMM operand ordering and avoids unsafe operand-swapping workarounds.

Notes:
- Only tensor quantization mode is currently supported.
- Implementation targets MI300X (CDNA3) FP8/BF8 kernels.
- Additional kernel coverage may be required for MI350X (CDNA4).

With this change, mixed FP8/BF8 backprop paths are supported and all parametrized unit tests in test_grouped_linear_accuracy_cutlass() pass successfully.
@aris134 aris134 self-assigned this Mar 6, 2026