Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand All @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down
3 changes: 3 additions & 0 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ int main(int argc, char* argv[])
auto result = arg_parser.parse(argc, argv);

if(!result)
{
arg_parser.print();
return -1;
}

try
{
Expand Down
7 changes: 4 additions & 3 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand All @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down Expand Up @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ void bquant_quantgrouped_bf16fp4_instance_factory(
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#if !defined(CK_GFX950_SUPPORT)
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
#endif
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down Expand Up @@ -196,7 +196,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
40 changes: 19 additions & 21 deletions include/ck_tile/host/tensor_shuffle_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)

if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp),
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}

Expand Down Expand Up @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
kKLanePerWarp,
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
Expand Down
37 changes: 37 additions & 0 deletions include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,41 @@ constexpr index_t get_k_warp_tile()
#endif
}

template <typename PrecType, index_t N_Warp_Tile>
constexpr index_t get_k_warp_tile_for_preshuffle_b()
{
#if CK_TILE_USE_WMMA
return 16;
#else
// When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately
// to support both dwordx4 loading instructions and MFMA instruction requirements.
// A single dwordx4 load may feed one or more MFMA instructions, or conversely,
// multiple loads may be required for a single MFMA instruction with a larger K dimension
// (e.g., 16x16x128 on gfx950).

// To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4)
// from global memory.
const index_t kMaxBytesPerLoad = 16; // buffer load max 16 bytes
const index_t kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
const index_t kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
const index_t kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;

// Minimum K_Warp_Tile required by MFMA instructions
const index_t kMfmaN16Index = 0;
const index_t kMfmaN32Index = 1;
#if defined(CK_GFX950_SUPPORT)
const index_t kF8MfmaMaxK[2] = {128, 64};
const index_t kF16MfmaMaxK[2] = {32, 16};
#else
const index_t kF8MfmaMaxK[2] = {32, 16};
const index_t kF16MfmaMaxK[2] = {16, 8};
#endif
const bool kIsF8 = std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];

return max(kKPerWarp, kMfmaMaxK);
#endif
}

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
}

constexpr index_t k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();

return k_b_per_load;
Comment on lines +44 to +47
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetKBPerLoad() was simplified to WarpTile::N * WarpTile::K / get_warp_size(), but MakeBFlatDramTileDistribution() still sets KRepeatInWave = 2 on __gfx11__ and asserts TileShape::flatKPerWarp == KThdPerWave * KBPerLoad. With the new formula, that static_assert will fail on gfx11 (it effectively becomes N*K == (warp_size/2) * (N*K/warp_size)). Either fold KRepeatInWave (gfx11) into the KB-per-load calculation, or keep the previous scaling logic so the distribution invariants remain valid.

Suggested change
constexpr index_t k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();
return k_b_per_load;
constexpr index_t base_k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();
#if defined(__gfx11__)
// On gfx11, MakeBFlatDramTileDistribution() uses KRepeatInWave = 2 and asserts
// TileShape::flatKPerWarp == KThdPerWave * KBPerLoad. To keep this invariant valid,
// fold KRepeatInWave into KBPerLoad here.
return base_k_b_per_load * 2;
#else
return base_k_b_per_load;
#endif

Copilot uses AI. Check for mistakes.
}

template <typename Problem>
Expand Down
110 changes: 55 additions & 55 deletions test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_base.cpp
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on
// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);
// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"

#include <gtest/gtest.h>
#include <memory>

#include "test_gemm_quant_fixtures.hpp"

// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

// 2d block sizes for BQuant
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;

// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// clang-format off
using ABQuantTypes = ::testing::Types<
// PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ)
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize, ColumnMajor>,

std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigBase<FP8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<ColumnMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>,
std::tuple<RowMajor, RowMajor, RowMajor, RowMajor, BF8, BF8, float, Half, ABQuantGrouped, GemmConfigBase<BF8>, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

// Test suite for ABQuant
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantTypes);

// AQuant tests
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
Loading