Skip to content
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand All @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down
3 changes: 3 additions & 0 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ int main(int argc, char* argv[])
auto result = arg_parser.parse(argc, argv);

if(!result)
{
arg_parser.print();
return -1;
}

try
{
Expand Down
7 changes: 4 additions & 3 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand All @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down Expand Up @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#if !defined(CK_GFX950_SUPPORT)
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down Expand Up @@ -196,7 +196,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
40 changes: 19 additions & 21 deletions include/ck_tile/host/tensor_shuffle_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)

if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp),
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}

Expand Down Expand Up @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
kKLanePerWarp,
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
Expand Down
37 changes: 37 additions & 0 deletions include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,41 @@ constexpr index_t get_k_warp_tile()
#endif
}

template <typename PrecType, index_t N_Warp_Tile>
constexpr index_t get_k_warp_tile_for_preshuffle_b()
{
#if CK_TILE_USE_WMMA
return 16;
#else
// When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately
// to support both dwordx4 loading instructions and MFMA instruction requirements.
// A single dwordx4 load may feed one or more MFMA instructions, or conversely,
// multiple loads may be required for a single MFMA instruction with a larger K dimension
// (e.g., 16x16x128 on gfx950).

// To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4)
// from global memory.
const index_t kMaxBytesPerLoad = 16; // buffer load max 16 bytes
const index_t kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
const index_t kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
const index_t kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;

// Minimum K_Warp_Tile required by MFMA instructions
const index_t kMfmaN16Index = 0;
const index_t kMfmaN32Index = 1;
#if defined(CK_GFX950_SUPPORT)
const index_t kF8MfmaMaxK[2] = {128, 64};
const index_t kF16MfmaMaxK[2] = {32, 16};
#else
const index_t kF8MfmaMaxK[2] = {32, 16};
const index_t kF16MfmaMaxK[2] = {16, 8};
#endif
const bool kIsF8 = std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];

return max(kKPerWarp, kMfmaMaxK);
#endif
}

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
}

constexpr index_t k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();

return k_b_per_load;
}

template <typename Problem>
Expand Down
110 changes: 55 additions & 55 deletions test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"

#include <gtest/gtest.h>
#include <memory>

#include "test_gemm_quant_fixtures.hpp"

// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;

// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,

std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);

// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
Loading