From 080fa14140700d83656558203283d155be0e3d2d Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 22 Jan 2026 12:36:40 -0500 Subject: [PATCH 1/8] [CK TILE] Add new function get_k_warp_tile_for_preshuffle_b --- example/ck_tile/03_gemm/gemm_utils.hpp | 4 ++-- .../03_gemm/gemm_weight_preshuffle.cpp | 3 +++ .../ops/gemm/pipeline/tile_gemm_shape.hpp | 20 +++++++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index c1df27ecc82..c1a37c8577a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 85f8c346c9a..d4c55de9e7c 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -94,7 +94,10 @@ int main(int argc, char* argv[]) auto result = arg_parser.parse(argc, argv); if(!result) + { + arg_parser.print(); return -1; + } try { diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 525a4ef9fc6..429522ac68f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -66,4 +66,24 @@ constexpr index_t get_k_warp_tile() #endif } +template +constexpr index_t get_k_warp_tile_for_preshuffle_b() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(N_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + // K value is determined by the maximum bytes that can be loaded in a single instruction + // This K value is sufficient for MFMA/WMMA shapes: 16x16x16, 16x16x32, 32x32x16 + const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes + const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); + const int KLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + return kMaxElementsPerLoad * KLanePerWarp; +#endif +} + } // namespace ck_tile From dc83e285e10ba0cdd2dfee9a1208746301d712ed Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 22 Jan 2026 17:49:39 -0500 Subject: [PATCH 2/8] [CK TILE] simplify function GetKBPerLoad --- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157cb..e33d525e283 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -39,21 +39,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { + using BDataType = remove_cvref_t; using TileShape = typename Problem::BlockGemmShape; -#if defined(__gfx11__) - constexpr index_t scale = 4; -#else - constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; -#endif - if constexpr(TileShape::WarpTile::at(I1) == 32) - { - return TileShape::WarpTile::at(I2) * scale / 2; - } - else - { - static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) * scale / 4; - } + + constexpr index_t k_b_per_load = + TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); + + /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in + * Preshuffle B */ + static_assert(k_b_per_load * sizeof(BDataType) == 16); + + return k_b_per_load; } template From 109bfa155852d6eda9aa7a0cd787a650c79b0030 Mon Sep 17 00:00:00 2001 From: Cong MA Date: Fri, 23 Jan 2026 11:00:07 -0500 Subject: [PATCH 3/8] [CK TILE] Update get_k_warp_tile_for_preshuffle_b for MI350 --- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 429522ac68f..b9382dee842 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -69,21 +69,26 @@ constexpr index_t get_k_warp_tile() template constexpr index_t get_k_warp_tile_for_preshuffle_b() { -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(N_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - // K value is determined by the maximum bytes that can be loaded in a single instruction - // This K value is sufficient for MFMA/WMMA shapes: 16x16x16, 16x16x32, 32x32x16 const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); - const int KLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; - return kMaxElementsPerLoad * KLanePerWarp; + const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + + const index_t kMfmaN16Index = 0; + const index_t kMfmaN32Index = 1; +#if defined(CK_GFX950_SUPPORT) + const index_t kF8MfmaMaxK[2] = {128, 64}; + const index_t kF16MfmaMaxK[2] = {32, 16}; +#else + const index_t kF8MfmaMaxK[2] = {32, 16}; + const index_t kF16MfmaMaxK[2] = {16, 8}; #endif + const bool kIsF8 = + std::is_same_v || std::is_same_v; + const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; + const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; + + return max(kKPerWarp, kMfmaMaxK); } } // namespace ck_tile From bc91bb7dd718ae87f1643d2ee1074970f6a78da8 Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Fri, 23 Jan 2026 18:41:35 -0500 Subject: [PATCH 4/8] [CK TILE] Apply get_k_warp_tile_for_preshuffle_b in examples and tests --- .../ck_tile/17_grouped_gemm/grouped_gemm.hpp | 7 ++-- .../quant_grouped_gemm_config.hpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 4 +- include/ck_tile/host/tensor_shuffle_utils.hpp | 40 +++++++++---------- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 28 +++++++++---- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 5 --- .../gemm_block_scale/test_gemm_quant_base.hpp | 14 ++++--- .../test_gemm_quant_fixtures.hpp | 27 ++++--------- .../test_gemm_multi_abd_util.hpp | 23 +---------- .../test_gemm_pipeline_util.hpp | 5 ++- .../test_grouped_gemm_preshuffle_util.hpp | 2 +- 11 files changed, 68 insertions(+), 89 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 462f11e4055..905d3ffc72b 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp index a1f287df6bb..2ea28ec558b 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 37fc998e5ba..37e46c6b04b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -145,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -175,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 7cd9889d78d..7f16a4bde08 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, - k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp), + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); } } @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmC int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + kKLanePerWarp, + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index b9382dee842..63bec56e200 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -69,26 +69,38 @@ constexpr index_t get_k_warp_tile() template constexpr index_t get_k_warp_tile_for_preshuffle_b() { +#if CK_TILE_USE_WMMA + return 16; +#else + // When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately + // to support both dwordx4 loading instructions and MFMA instruction requirements. + // A single dwordx4 load may feed one or more MFMA instructions, or conversely, + // multiple loads may be required for a single MFMA instruction with a larger K dimension + // (e.g., 16x16x128 on gfx950). + + // To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4) + // from global memory. const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); - const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; - const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + // Minimum K_Warp_Tile required by MFMA instructions const index_t kMfmaN16Index = 0; const index_t kMfmaN32Index = 1; #if defined(CK_GFX950_SUPPORT) - const index_t kF8MfmaMaxK[2] = {128, 64}; + const index_t kF8MfmaMaxK[2] = {128, 64}; const index_t kF16MfmaMaxK[2] = {32, 16}; #else - const index_t kF8MfmaMaxK[2] = {32, 16}; + const index_t kF8MfmaMaxK[2] = {32, 16}; const index_t kF16MfmaMaxK[2] = {16, 8}; #endif - const bool kIsF8 = - std::is_same_v || std::is_same_v; - const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; - const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; + const bool kIsF8 = std::is_same_v || std::is_same_v; + const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; + const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; return max(kKPerWarp, kMfmaMaxK); +#endif } } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index e33d525e283..d1e498361ad 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -39,16 +39,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { - using BDataType = remove_cvref_t; using TileShape = typename Problem::BlockGemmShape; constexpr index_t k_b_per_load = TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); - /* The k_b_per_load should meet the requirement that each thread loads 16 bytes in - * Preshuffle B */ - static_assert(k_b_per_load * sizeof(BDataType) == 16); - return k_b_per_load; } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 8c9955da749..4dbc122110f 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -16,6 +16,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" // Forward declarations for quant type-specific implementations template @@ -74,11 +75,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; - static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; - static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; - static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; - static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + static constexpr ck_tile::index_t K_Warp_Tile = + GemmConfig::PreshuffleB + ? ck_tile::get_k_warp_tile_for_preshuffle_b() + : ck_tile::get_k_warp_tile(); + static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; + static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; + static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; + static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; static constexpr bool kPadM = GemmConfig::kPadM; static constexpr bool kPadN = GemmConfig::kPadN; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935efc..bee2e7ed719 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -6,16 +6,7 @@ #include "test_gemm_quant_base.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" - -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else - return is_8bit ? 64 : 32; -#endif -} +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" struct GemmConfigBase { @@ -50,23 +41,21 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + // K_Warp_Tile is derived from N_Warp_Tile and BDataType }; struct GemmConfigDecode : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; }; struct GemmConfigPrefill : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; }; struct GemmConfigMxFp4 : public GemmConfigBase diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index f6620c105d4..77ed9f9bb66 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" using AddScale = ck_tile::element_wise::AddScale; using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; @@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -#endif -} - template constexpr ck_tile::index_t get_k_warp_tile() @@ -86,7 +87,7 @@ struct config static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template @@ -102,7 +103,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index e588ad2cc12..a490cf42f1b 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 16; static const ck_tile::index_t N_Warp_Tile = 16; static const ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported From 70bd8f8143e002b8c46037b38af8000cc9e38d18 Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Mon, 26 Jan 2026 16:26:49 -0500 Subject: [PATCH 5/8] [CK TIEL] Fix a const type qualifier error --- .../test_grouped_gemm_preshuffle_util.hpp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index a490cf42f1b..49e48af3c62 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -30,25 +30,25 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using PersistentType = typename Tuple::Persistent; static constexpr bool Persistent = PersistentType::value; - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = true; // preshuffle pipeline requires k padding + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = true; // preshuffle pipeline requires k padding - static const int kBlockPerCu = Tuple::BlockPerCu_; + static constexpr int kBlockPerCu = Tuple::BlockPerCu_; // Tile dimensions from tuple - static const ck_tile::index_t M_Tile = Tuple::M_Tile_; - static const ck_tile::index_t N_Tile = Tuple::N_Tile_; - static const ck_tile::index_t K_Tile = Tuple::K_Tile_; - - static const ck_tile::index_t M_Warp = 1; - static const ck_tile::index_t N_Warp = 4; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 16; - static const ck_tile::index_t N_Warp_Tile = 16; - static const ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile_for_preshuffle_b(); + static constexpr ck_tile::index_t M_Tile = Tuple::M_Tile_; + static constexpr ck_tile::index_t N_Tile = Tuple::N_Tile_; + static constexpr ck_tile::index_t K_Tile = Tuple::K_Tile_; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported From 89d4d517b5336472c7fce472c691c5f29e6a1ce6 Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Mon, 26 Jan 2026 16:55:49 -0500 Subject: [PATCH 6/8] [CK TIEL] Fix type error --- include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 63bec56e200..e39b02305d5 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -80,10 +80,10 @@ constexpr index_t get_k_warp_tile_for_preshuffle_b() // To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4) // from global memory. - const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes - const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); - const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; - const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + const index_t kMaxBytesPerLoad = 16; // buffer load max 16 bytes + const index_t kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); + const index_t kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + const index_t kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; // Minimum K_Warp_Tile required by MFMA instructions const index_t kMfmaN16Index = 0; From 6ba8427812fca86831734f6e14d25a80a20afd0c Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Mon, 26 Jan 2026 22:51:02 -0500 Subject: [PATCH 7/8] [CK TILE] set proper K_Warp_Tile for quant gemm tests --- .../test_gemm_quant_abquant_base.cpp | 110 +++++++++--------- .../test_gemm_quant_abquant_padding.cpp | 78 ++++++------- .../test_gemm_quant_abquant_preshuffle_2d.cpp | 88 +++++++------- .../test_gemm_quant_aquant_base_ccr.cpp | 8 +- .../test_gemm_quant_aquant_base_rcr.cpp | 8 +- .../test_gemm_quant_aquant_base_rrr_crr.cpp | 12 +- .../test_gemm_quant_aquant_prefill.cpp | 6 +- .../test_gemm_quant_aquant_preshuffle.cpp | 16 +-- .../test_gemm_quant_aquant_transpose_c.cpp | 4 +- .../gemm_block_scale/test_gemm_quant_base.hpp | 13 +-- .../test_gemm_quant_bquant_1d_128.cpp | 6 +- .../test_gemm_quant_bquant_1d_64.cpp | 8 +- .../test_gemm_quant_bquant_2d_large_n.cpp | 8 +- .../test_gemm_quant_bquant_2d_medium_n.cpp | 16 +-- .../test_gemm_quant_bquant_2d_small_n.cpp | 16 +-- ...quant_bquant_preshuffleQuant_decode_1d.cpp | 4 +- ...quant_bquant_preshuffleQuant_decode_2d.cpp | 20 ++-- ...uant_bquant_preshuffleQuant_prefill_1d.cpp | 8 +- ...uant_bquant_preshuffleQuant_prefill_2d.cpp | 40 +++---- ...gemm_quant_bquant_preshuffle_decode_1d.cpp | 4 +- ...gemm_quant_bquant_preshuffle_decode_2d.cpp | 20 ++-- ...emm_quant_bquant_preshuffle_prefill_1d.cpp | 8 +- ...emm_quant_bquant_preshuffle_prefill_2d.cpp | 40 +++---- ..._quant_bquant_preshuffle_tiled_permute.cpp | 6 +- .../test_gemm_quant_bquant_transpose.cpp | 24 ++-- .../test_gemm_quant_fixtures.hpp | 67 +++++++---- .../test_gemm_quant_rowcol.cpp | 4 +- .../test_gemm_quant_tensor.cpp | 4 +- 28 files changed, 335 insertions(+), 311 deletions(-) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp index 6e3e95fccf1..7be44bc0244 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp @@ -1,55 +1,55 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant tests -// Tuple format: -// clang-format off -using ABQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for ABQuant -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp index 5247a4405de..db9f2bff441 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp @@ -1,39 +1,39 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant padding padding tests -// Tuple format: -// clang-format off -using ABQuantPaddingTypes = ::testing::Types< - std::tuple ->; -// clang-format on - -// Test suite for ABQuant Padding -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 832, 832); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant padding padding tests +// Tuple format: +// clang-format off +using ABQuantPaddingTypes = ::testing::Types< + std::tuple, GroupSize, GroupSize, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant Padding +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp index 793c9bd1df1..9d4d1d2c8bc 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -1,44 +1,44 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant tests -// Tuple format: -// clang-format off -using ABQuantPreshuffleBTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for ABQuant -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleBTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp index 0e04f9fc9e9..47b06d8f9af 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp @@ -25,10 +25,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseCCRTypes = ::testing::Types< // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp index da32c063042..279109c6442 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp @@ -25,10 +25,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseRCRTypes = ::testing::Types< // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp index 6e90c44764c..623752b1f69 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp @@ -25,14 +25,14 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseRRRCRRTypes = ::testing::Types< // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp index 133c11860ac..ff29274ee25 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPrefillTypes = ::testing::Types< // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp index 35d15f93541..072524e9866 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp @@ -25,16 +25,16 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPreshuffleTypes = ::testing::Types< // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp index a2a4c2c38b6..2b299ce0da8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp @@ -25,8 +25,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantTransposeCTypes = ::testing::Types< // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 4dbc122110f..2aaa7ac42be 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -75,14 +75,11 @@ class TestCkTileGemmQuantBase : public ::testing::Test static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - static constexpr ck_tile::index_t K_Warp_Tile = - GemmConfig::PreshuffleB - ? ck_tile::get_k_warp_tile_for_preshuffle_b() - : ck_tile::get_k_warp_tile(); - static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; - static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; - static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; - static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; + static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant; + static constexpr bool PreshuffleB = GemmConfig::PreshuffleB; + static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN; + static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; static constexpr bool kPadM = GemmConfig::kPadM; static constexpr bool kPadN = GemmConfig::kPadN; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4e..606596071c0 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 1019caf1bca..97640fbd0a7 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -24,10 +24,10 @@ using GroupSize64 = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant1D64Types = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp index a8b6dcd14b0..d0df5ec0804 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp @@ -24,10 +24,10 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant2DLargeNTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp index 67d52ef874c..2cdca22f65b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp @@ -27,14 +27,14 @@ using GroupSize2D64N = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant2DMediumNTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp index 865713992d7..4c8ad303205 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp @@ -28,14 +28,14 @@ using GroupSize2D16N = ck_tile::QuantGroupShape>; // clang-format off using BQuant2DSmallNTypes = ::testing::Types< // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp index 661fd5bd336..1eb335490a1 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp @@ -24,8 +24,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleDecode1DTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp index fb4020bcd7e..24ef4a8f0a4 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp @@ -31,16 +31,16 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // clang-format off using BPreshuffleDecode2DTypes = ::testing::Types< // 2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp index 0d4e4d5f034..7eb860f6573 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill1DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp index edc7bcaa090..5823bac894a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp @@ -30,26 +30,26 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill2DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp index cf599ebbfde..f1414e8025a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp @@ -24,8 +24,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleDecode1DTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp index 66fb62e67e2..c212660ec68 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp @@ -31,16 +31,16 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // clang-format off using BPreshuffleDecode2DTypes = ::testing::Types< // 2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp index 3f6dd225d71..545823158da 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill1DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp index ace07a37ae5..c1dd76b419d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp @@ -30,26 +30,26 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill2DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp index 8a05f5812a8..e1a467635c2 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp @@ -24,9 +24,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleTiledPermuteTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp index 230dd8f0fc1..6d1e0b95524 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp @@ -26,20 +26,20 @@ using GroupSize2D64N = ck_tile::QuantGroupShape>; // clang-format off using BQuantTransposeTypes = ::testing::Types< // some cases with transpose layouts - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, // pkint4 + transpose cases - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index bee2e7ed719..c3cba3bcf6b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -8,6 +8,7 @@ #include "ck_tile/host/tensor_shuffle_utils.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -41,83 +42,109 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - // K_Warp_Tile is derived from N_Warp_Tile and BDataType + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; -struct GemmConfigDecode : public GemmConfigBase +template +struct GemmConfigDecode : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; static constexpr ck_tile::index_t K_Tile = 256; }; -struct GemmConfigPrefill : public GemmConfigBase +template +struct GemmConfigPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; }; -struct GemmConfigMxFp4 : public GemmConfigBase +template +struct GemmConfigMxFp4 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; }; -struct GemmConfigPreshuffleQuant : public GemmConfigBase +template +struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigTransposeC : public GemmConfigBase +template +struct GemmConfigTransposeC : public GemmConfigBase { static constexpr bool TransposeC = true; }; -struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase +template +struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; static constexpr bool TransposeC = true; }; -struct GemmConfigPadding : public GemmConfigBase +template +struct GemmConfigPadding : public GemmConfigBase { static constexpr bool kPadN = true; static constexpr bool kPadK = true; }; -struct GemmConfigPreshuffleBDecode : public GemmConfigDecode +template +struct GemmConfigPreshuffleBDecode : public GemmConfigDecode { + using Base = GemmConfigDecode; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode +template +struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill +template +struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill { + using Base = GemmConfigPrefill; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill +template +struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill +template +struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill { - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + using Base = GemmConfigPreshuffleBPrefill; + static constexpr int N_Repeat = Base::N_Tile / Base::N_Warp_Tile / Base::N_Warp; static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode +template +struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode { + using Base = GemmConfigPreshuffleBDecode; static constexpr bool PreshuffleQuant = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; template @@ -355,7 +382,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -642,7 +669,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -949,7 +976,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -1170,7 +1197,7 @@ class TestCkTileGemmRowColQuant } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -1384,7 +1411,7 @@ class TestCkTileGemmTensorQuant } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp index bb0fa218998..e9555fdd448 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -23,8 +23,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using RowColQuantTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp index 8b4c90f8b92..496ebba0fcf 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -23,8 +23,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using TensorQuantTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on From ed0eadb8c8e9f30869e11cbe19c8005a9b6a2f7b Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Mon, 26 Jan 2026 23:11:14 -0500 Subject: [PATCH 8/8] [CK TILE] disable tests on gfx950 --- .../gemm_bquant_quantgrouped_bf16mxfp4.cpp | 2 ++ .../ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp | 2 ++ .../gemm_block_scale/test_gemm_quant_bquant_transpose.cpp | 4 ++++ 3 files changed, 8 insertions(+) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp index 31d263ea1df..63fef41a403 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -26,12 +26,14 @@ void bquant_quantgrouped_bf16fp4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; +#if !defined(CK_GFX950_SUPPORT) lut[hash_multiple_strings( {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; +#endif lut[hash_multiple_strings( {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 97640fbd0a7..78dfc7ffd73 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -31,6 +31,7 @@ using BQuant1D64Types = ::testing::Types< >; // clang-format on +#if !defined(CK_GFX950_SUPPORT) // Test suite for BQuant 1D 64 TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); @@ -39,3 +40,4 @@ TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) { this->run_test_with_validation(1024, 1024, 1024); } +#endif diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp index 6d1e0b95524..021bc5f26ac 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp @@ -26,17 +26,21 @@ using GroupSize2D64N = ck_tile::QuantGroupShape>; // clang-format off using BQuantTransposeTypes = ::testing::Types< // some cases with transpose layouts +#if !defined(CK_GFX950_SUPPORT) std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, std::tuple, GroupSize64>, std::tuple, GroupSize64>, +#endif std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, std::tuple, GroupSize2D64N>, std::tuple, GroupSize2D64N>, // pkint4 + transpose cases +#if !defined(CK_GFX950_SUPPORT) std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, std::tuple, GroupSize64>, std::tuple, GroupSize64>, +#endif std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, std::tuple, GroupSize2D64N>, std::tuple, GroupSize2D64N>