Skip to content
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand All @@ -280,7 +280,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down
3 changes: 3 additions & 0 deletions example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ int main(int argc, char* argv[])
auto result = arg_parser.parse(argc, argv);

if(!result)
{
arg_parser.print();
return -1;
}

try
{
Expand Down
7 changes: 4 additions & 3 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand All @@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
Expand Down Expand Up @@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool kPadK = true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down Expand Up @@ -175,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
Expand Down
40 changes: 19 additions & 21 deletions include/ck_tile/host/tensor_shuffle_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)

if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp),
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}

Expand Down Expand Up @@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
kKLanePerWarp,
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}
Expand Down
37 changes: 37 additions & 0 deletions include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,41 @@ constexpr index_t get_k_warp_tile()
#endif
}

template <typename PrecType, index_t N_Warp_Tile>
constexpr index_t get_k_warp_tile_for_preshuffle_b()
{
#if CK_TILE_USE_WMMA
return 16;
#else
// When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately
// to support both dwordx4 loading instructions and MFMA instruction requirements.
// A single dwordx4 load may feed one or more MFMA instructions, or conversely,
// multiple loads may be required for a single MFMA instruction with a larger K dimension
// (e.g., 16x16x128 on gfx950).

// To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4)
// from global memory.
const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes
const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;

// Minimum K_Warp_Tile required by MFMA instructions
const index_t kMfmaN16Index = 0;
const index_t kMfmaN32Index = 1;
#if defined(CK_GFX950_SUPPORT)
const index_t kF8MfmaMaxK[2] = {128, 64};
const index_t kF16MfmaMaxK[2] = {32, 16};
#else
const index_t kF8MfmaMaxK[2] = {32, 16};
const index_t kF16MfmaMaxK[2] = {16, 8};
#endif
const bool kIsF8 = std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];

return max(kKPerWarp, kMfmaMaxK);
#endif
Comment thread
CongMa13 marked this conversation as resolved.
Outdated
}

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
}

constexpr index_t k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();

return k_b_per_load;
Comment thread
CongMa13 marked this conversation as resolved.
}

template <typename Problem>
Expand Down
14 changes: 9 additions & 5 deletions test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"

// Forward declarations for quant type-specific implementations
template <ck_tile::QuantType QT>
Expand Down Expand Up @@ -74,11 +75,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test

static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
static constexpr ck_tile::index_t K_Warp_Tile =
GemmConfig::PreshuffleB
? ck_tile::get_k_warp_tile_for_preshuffle_b<BDataType, N_Warp_Tile>()
: ck_tile::get_k_warp_tile<BDataType, N_Warp_Tile>();
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;

static constexpr bool kPadM = GemmConfig::kPadM;
static constexpr bool kPadN = GemmConfig::kPadN;
Expand Down
27 changes: 8 additions & 19 deletions test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,7 @@
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"

template <bool is_8bit>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
return is_8bit ? 64 : 32;
#endif
}
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"

struct GemmConfigBase
{
Expand Down Expand Up @@ -50,23 +41,21 @@ struct GemmConfigBase

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
// K_Warp_Tile is derived from N_Warp_Tile and BDataType
};

struct GemmConfigDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};

struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
};

struct GemmConfigMxFp4 : public GemmConfigBase
Expand Down
23 changes: 1 addition & 22 deletions test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"

Comment thread
CongMa13 marked this conversation as resolved.
using AddScale = ck_tile::element_wise::AddScale;
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
Expand All @@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}

template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
#endif
}

template <typename A0DataType,
typename B0DataType,
typename AccDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"

template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
Expand Down Expand Up @@ -86,7 +87,7 @@ struct config

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
};

template <typename Datatype>
Expand All @@ -102,7 +103,7 @@ struct config_wmma

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
};

template <typename Tuple>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
Comment thread
CongMa13 marked this conversation as resolved.
Outdated

static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
static constexpr bool TransposeC = false; // transpose c is not supported
Expand Down