diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index c1df27ecc82..c1a37c8577a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 85f8c346c9a..d4c55de9e7c 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -94,7 +94,10 @@ int main(int argc, char* argv[]) auto result = arg_parser.parse(argc, argv); if(!result) + { + arg_parser.print(); return -1; + } try { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 462f11e4055..905d3ffc72b 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool kPadK = true; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp index a1f287df6bb..2ea28ec558b 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm_config.hpp @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp index 31d263ea1df..63fef41a403 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -26,12 +26,14 @@ void bquant_quantgrouped_bf16fp4_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; +#if !defined(CK_GFX950_SUPPORT) lut[hash_multiple_strings( {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; +#endif lut[hash_multiple_strings( {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 37117eaa0f5..24e6e77214c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -166,7 +166,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -196,7 +196,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 7cd9889d78d..7f16a4bde08 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, gemmConfig.N_Warp_Tile, - k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp), + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); } } @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmC int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + constexpr int kKLanePerWarp = 2; + constexpr int kABK1PerLane = 8; + int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane; ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, - divisor, + kKLanePerWarp, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7}); } else { - int divisor = 1; + int kKLanePerWarp = 1; if(ck_tile::is_gfx11_supported()) { - divisor = 1; + kKLanePerWarp = 1; } else { assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; + kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile; } ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + kKLanePerWarp, + gemmConfig.K_Warp_Tile / kKLanePerWarp}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 525a4ef9fc6..e39b02305d5 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -66,4 +66,41 @@ constexpr index_t get_k_warp_tile() #endif } +template +constexpr index_t get_k_warp_tile_for_preshuffle_b() +{ +#if CK_TILE_USE_WMMA + return 16; +#else + // When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately + // to support both dwordx4 loading instructions and MFMA instruction requirements. + // A single dwordx4 load may feed one or more MFMA instructions, or conversely, + // multiple loads may be required for a single MFMA instruction with a larger K dimension + // (e.g., 16x16x128 on gfx950). + + // To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4) + // from global memory. + const index_t kMaxBytesPerLoad = 16; // buffer load max 16 bytes + const index_t kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType); + const index_t kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile; + const index_t kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp; + + // Minimum K_Warp_Tile required by MFMA instructions + const index_t kMfmaN16Index = 0; + const index_t kMfmaN32Index = 1; +#if defined(CK_GFX950_SUPPORT) + const index_t kF8MfmaMaxK[2] = {128, 64}; + const index_t kF16MfmaMaxK[2] = {32, 16}; +#else + const index_t kF8MfmaMaxK[2] = {32, 16}; + const index_t kF16MfmaMaxK[2] = {16, 8}; +#endif + const bool kIsF8 = std::is_same_v || std::is_same_v; + const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index; + const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex]; + + return max(kKPerWarp, kMfmaMaxK); +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157cb..d1e498361ad 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -40,20 +40,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() { using TileShape = typename Problem::BlockGemmShape; -#if defined(__gfx11__) - constexpr index_t scale = 4; -#else - constexpr index_t scale = get_warp_size() == 32 ? 2 : 1; -#endif - if constexpr(TileShape::WarpTile::at(I1) == 32) - { - return TileShape::WarpTile::at(I2) * scale / 2; - } - else - { - static_assert(TileShape::WarpTile::at(I1) == 16); - return TileShape::WarpTile::at(I2) * scale / 4; - } + + constexpr index_t k_b_per_load = + TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size(); + + return k_b_per_load; } template diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp index 6e3e95fccf1..7be44bc0244 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp @@ -1,55 +1,55 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant tests -// Tuple format: -// clang-format off -using ABQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for ABQuant -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize, ColumnMajor>, + + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp index 5247a4405de..db9f2bff441 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_padding.cpp @@ -1,39 +1,39 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant padding padding tests -// Tuple format: -// clang-format off -using ABQuantPaddingTypes = ::testing::Types< - std::tuple ->; -// clang-format on - -// Test suite for ABQuant Padding -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 832, 832); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant padding padding tests +// Tuple format: +// clang-format off +using ABQuantPaddingTypes = ::testing::Types< + std::tuple, GroupSize, GroupSize, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant Padding +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPaddingTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 832, 832); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp index 793c9bd1df1..9d4d1d2c8bc 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_2d.cpp @@ -1,44 +1,44 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using ABQuantGrouped = - std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for ABQuant tests -// Tuple format: -// clang-format off -using ABQuantPreshuffleBTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for ABQuant -TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleBTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, GroupSize, GroupSize, ColumnMajor>, + std::tuple, GroupSize, GroupSize2D128N, ColumnMajor> +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleBTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp index 0e04f9fc9e9..47b06d8f9af 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp @@ -25,10 +25,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseCCRTypes = ::testing::Types< // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp index da32c063042..279109c6442 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp @@ -25,10 +25,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseRCRTypes = ::testing::Types< // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp index 6e90c44764c..623752b1f69 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp @@ -25,14 +25,14 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantBaseRRRCRRTypes = ::testing::Types< // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp index a7ab4120a18..b1d6652ce6b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using AQuantMemDecodeInterwaveTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp index 483138d7110..334ec402de9 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using AQuantMemDecodeIntrawaveTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp index 7e851d9bd39..cf1ef115774 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using AQuantMemPrefillInterwaveTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp index 911af678df5..ae538e0c8b8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPrefillTypes = ::testing::Types< // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp index 35d15f93541..072524e9866 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp @@ -25,16 +25,16 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPreshuffleTypes = ::testing::Types< // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp index a2a4c2c38b6..2b299ce0da8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp @@ -25,8 +25,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantTransposeCTypes = ::testing::Types< // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 8c9955da749..2aaa7ac42be 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -16,6 +16,7 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" // Forward declarations for quant type-specific implementations template diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4e..606596071c0 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 1019caf1bca..78dfc7ffd73 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -24,13 +24,14 @@ using GroupSize64 = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant1D64Types = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64> >; // clang-format on +#if !defined(CK_GFX950_SUPPORT) // Test suite for BQuant 1D 64 TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); @@ -39,3 +40,4 @@ TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) { this->run_test_with_validation(1024, 1024, 1024); } +#endif diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp index a8b6dcd14b0..d0df5ec0804 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp @@ -24,10 +24,10 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant2DLargeNTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp index 67d52ef874c..2cdca22f65b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp @@ -27,14 +27,14 @@ using GroupSize2D64N = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant2DMediumNTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp index 865713992d7..4c8ad303205 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp @@ -28,14 +28,14 @@ using GroupSize2D16N = ck_tile::QuantGroupShape>; // clang-format off using BQuant2DSmallNTypes = ::testing::Types< // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp index 661fd5bd336..1eb335490a1 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_1d.cpp @@ -24,8 +24,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleDecode1DTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp index fb4020bcd7e..24ef4a8f0a4 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_decode_2d.cpp @@ -31,16 +31,16 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // clang-format off using BPreshuffleDecode2DTypes = ::testing::Types< // 2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp index 0d4e4d5f034..7eb860f6573 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_1d.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill1DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp index edc7bcaa090..5823bac894a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffleQuant_prefill_2d.cpp @@ -30,26 +30,26 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill2DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp index cf599ebbfde..f1414e8025a 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp @@ -24,8 +24,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleDecode1DTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp index 66fb62e67e2..c212660ec68 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp @@ -31,16 +31,16 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // clang-format off using BPreshuffleDecode2DTypes = ::testing::Types< // 2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp index 3f6dd225d71..545823158da 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp @@ -24,10 +24,10 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill1DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp index ace07a37ae5..c1dd76b419d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp @@ -30,26 +30,26 @@ using GroupSize2D128N = ck_tile::QuantGroupShape> // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshufflePrefill2DTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D8N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D16N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D32N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N>, + std::tuple, GroupSize2D128N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp index 8a05f5812a8..e1a467635c2 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp @@ -24,9 +24,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BPreshuffleTiledPermuteTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp index 230dd8f0fc1..021bc5f26ac 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp @@ -26,20 +26,24 @@ using GroupSize2D64N = ck_tile::QuantGroupShape>; // clang-format off using BQuantTransposeTypes = ::testing::Types< // some cases with transpose layouts - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple, +#if !defined(CK_GFX950_SUPPORT) + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, +#endif + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, // pkint4 + transpose cases - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple +#if !defined(CK_GFX950_SUPPORT) + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, GroupSize64>, + std::tuple, GroupSize64>, +#endif + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, GroupSize2D64N>, + std::tuple, GroupSize2D64N> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9652dd449d1..d837300af61 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -6,17 +6,9 @@ #include "test_gemm_quant_base.hpp" #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else - return is_8bit ? 64 : 32; -#endif -} - +template struct GemmConfigBase { static constexpr bool kPadM = false; @@ -50,26 +42,28 @@ struct GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; -struct GemmConfigDecode : public GemmConfigBase +template +struct GemmConfigDecode : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; }; -struct GemmConfigPrefill : public GemmConfigBase +template +struct GemmConfigPrefill : public GemmConfigBase { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; }; -struct GemmConfigPrefillIntrawave : public GemmConfigBase +template +struct GemmConfigPrefillIntrawave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -77,7 +71,8 @@ struct GemmConfigPrefillIntrawave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -struct GemmConfigPrefillInterwave : public GemmConfigBase +template +struct GemmConfigPrefillInterwave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -85,7 +80,8 @@ struct GemmConfigPrefillInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigDecodeIntrawave : public GemmConfigBase +template +struct GemmConfigDecodeIntrawave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -93,7 +89,8 @@ struct GemmConfigDecodeIntrawave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -struct GemmConfigDecodeInterwave : public GemmConfigBase +template +struct GemmConfigDecodeInterwave : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 16; static constexpr ck_tile::index_t N_Tile = 64; @@ -101,66 +98,89 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigMxFp4 : public GemmConfigBase +template +struct GemmConfigMxFp4 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; }; -struct GemmConfigPreshuffleQuant : public GemmConfigBase +template +struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigTransposeC : public GemmConfigBase +template +struct GemmConfigTransposeC : public GemmConfigBase { static constexpr bool TransposeC = true; }; -struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase +template +struct GemmConfigPreshuffleQuantTransposeC : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; static constexpr bool TransposeC = true; }; -struct GemmConfigPadding : public GemmConfigBase +template +struct GemmConfigPadding : public GemmConfigBase { static constexpr bool kPadN = true; static constexpr bool kPadK = true; }; -struct GemmConfigPreshuffleBDecode : public GemmConfigDecode +template +struct GemmConfigPreshuffleBDecode : public GemmConfigDecode { + using Base = GemmConfigDecode; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode +template +struct GemmConfigPreshuffleQuantDecode : public GemmConfigDecode { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill +template +struct GemmConfigPreshuffleBPrefill : public GemmConfigPrefill { + using Base = GemmConfigPrefill; static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill +template +struct GemmConfigPreshuffleQuantPrefill : public GemmConfigPrefill { static constexpr bool PreshuffleQuant = true; }; -struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill +template +struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBPrefill { - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + using Base = GemmConfigPreshuffleBPrefill; + static constexpr int N_Repeat = Base::N_Tile / Base::N_Warp_Tile / Base::N_Warp; static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; -struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode +template +struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode { + using Base = GemmConfigPreshuffleBDecode; static constexpr bool PreshuffleQuant = true; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); }; template @@ -398,7 +418,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -616,7 +636,7 @@ class TestCkTileGemmAQuantMem throw std::runtime_error("Arguments not supported for AQuant kernel"); } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); @@ -902,7 +922,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -1209,7 +1229,7 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -1430,7 +1450,7 @@ class TestCkTileGemmRowColQuant } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; @@ -1644,7 +1664,7 @@ class TestCkTileGemmTensorQuant } ck_tile::launch_kernel(s, - ck_tile::make_kernel( + ck_tile::make_kernel::kBlockPerCu>( Kernel{}, grids, blocks, 0, kargs)); }; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp index bb0fa218998..e9555fdd448 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_rowcol.cpp @@ -23,8 +23,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using RowColQuantTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp index 8b4c90f8b92..496ebba0fcf 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_tensor.cpp @@ -23,8 +23,8 @@ using GroupSize = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using TensorQuantTypes = ::testing::Types< - std::tuple, - std::tuple + std::tuple, GroupSize>, + std::tuple, GroupSize> >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp index f6620c105d4..77ed9f9bb66 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" using AddScale = ck_tile::element_wise::AddScale; using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd; @@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -#endif -} - template constexpr ck_tile::index_t get_k_warp_tile() @@ -86,7 +87,7 @@ struct config static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template @@ -102,7 +103,7 @@ struct config_wmma static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index e588ad2cc12..49e48af3c62 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -30,25 +30,25 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using PersistentType = typename Tuple::Persistent; static constexpr bool Persistent = PersistentType::value; - static const bool kPadM = false; - static const bool kPadN = false; - static const bool kPadK = true; // preshuffle pipeline requires k padding + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = true; // preshuffle pipeline requires k padding - static const int kBlockPerCu = Tuple::BlockPerCu_; + static constexpr int kBlockPerCu = Tuple::BlockPerCu_; // Tile dimensions from tuple - static const ck_tile::index_t M_Tile = Tuple::M_Tile_; - static const ck_tile::index_t N_Tile = Tuple::N_Tile_; - static const ck_tile::index_t K_Tile = Tuple::K_Tile_; - - static const ck_tile::index_t M_Warp = 1; - static const ck_tile::index_t N_Warp = 4; - static const ck_tile::index_t K_Warp = 1; - - static const ck_tile::index_t M_Warp_Tile = 16; - static const ck_tile::index_t N_Warp_Tile = 16; - static const ck_tile::index_t K_Warp_Tile = - ck_tile::get_k_warp_tile(); + static constexpr ck_tile::index_t M_Tile = Tuple::M_Tile_; + static constexpr ck_tile::index_t N_Tile = Tuple::N_Tile_; + static constexpr ck_tile::index_t K_Tile = Tuple::K_Tile_; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile_for_preshuffle_b(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported