From e232c17dbf2fa9edec1dafdc8aa2bd68c38a233e Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 21 Jan 2026 03:30:19 +0000 Subject: [PATCH 1/6] Refactor --- .../38_block_scale_gemm/CMakeLists.txt | 1 + .../gemm_abquant_quantgrouped.cpp | 8 +-- .../gemm_aquant_quantgrouped.cpp | 8 +-- ...mm_aquant_quantgrouped_preshufflequant.cpp | 8 +-- .../gemm_bquant_quantgrouped_bf16mxfp4.cpp | 8 +-- .../gemm_bquant_quantgrouped_bf8.cpp | 8 +-- .../gemm_bquant_quantgrouped_bf8i4.cpp | 8 +-- .../gemm_bquant_quantgrouped_fp8.cpp | 8 +-- .../gemm_bquant_quantgrouped_fp8i4.cpp | 8 +-- ...mm_bquant_quantgrouped_preshuffleb_bf8.cpp | 8 +-- ..._bquant_quantgrouped_preshuffleb_bf8i4.cpp | 8 +-- ...mm_bquant_quantgrouped_preshuffleb_fp8.cpp | 8 +-- ..._bquant_quantgrouped_preshuffleb_fp8i4.cpp | 8 +-- ...rouped_preshuffleb_preshufflequant_bf8.cpp | 8 +-- ...uped_preshuffleb_preshufflequant_bf8i4.cpp | 8 +-- ...rouped_preshuffleb_preshufflequant_fp8.cpp | 8 +-- ...uped_preshuffleb_preshufflequant_fp8i4.cpp | 8 +-- ...quant_quantgrouped_preshufflequant_bf8.cpp | 8 +-- ...ant_quantgrouped_preshufflequant_bf8i4.cpp | 8 +-- ...quant_quantgrouped_preshufflequant_fp8.cpp | 8 +-- ...ant_quantgrouped_preshufflequant_fp8i4.cpp | 8 +-- .../38_block_scale_gemm/gemm_quant.cpp | 70 +------------------ .../38_block_scale_gemm/gemm_quant_rowcol.cpp | 8 +-- .../38_block_scale_gemm/gemm_quant_tensor.cpp | 8 +-- .../38_block_scale_gemm/gemm_utils.hpp | 8 +++ 25 files changed, 99 insertions(+), 156 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index ec536f72878..13cbcc8b558 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -6,6 +6,7 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp index b1cd1a52a71..153ab4845a5 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp @@ -12,9 +12,8 @@ using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill; // template // using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode; -void abquant_quantgrouped_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); lut[hash_multiple_strings({"fp8", "abquant", "non-preshuffleb", @@ -135,4 +134,5 @@ void abquant_quantgrouped_instance_factory( BQuantGroupSize, ck_tile::QuantType::ABQuantGrouped>(arg_parser); }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ad1a4e0d100..016083be74c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -10,9 +10,8 @@ using GemmConfig = GemmConfigQuantDecode; // template // using GemmConfig = GemmConfigQuantPrefill; -void aquant_quantgrouped_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings( {"fp8", "aquant", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& @@ -56,4 +55,5 @@ void aquant_quantgrouped_instance_factory( QuantGroupSize, ck_tile::QuantType::AQuantGrouped>(arg_parser); }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped_preshufflequant.cpp index 45e8c28a4ed..32e2b3d6035 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped_preshufflequant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped_preshufflequant.cpp @@ -6,9 +6,8 @@ template using GemmConfig = GemmConfigPreshuffleQuantDecode; -void aquant_quantgrouped_preshufflequant_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings( {"fp8", "aquant", "preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& @@ -52,4 +51,5 @@ void aquant_quantgrouped_preshufflequant_instance_factory( QuantGroupSize, ck_tile::QuantType::AQuantGrouped>(arg_parser); }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp index 31d263ea1df..b8eb670135a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_bf16fp4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 82e30e56d2d..a95c0346cf7 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_bf8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); #ifndef CK_GFX950_SUPPORT @@ -55,4 +54,5 @@ void bquant_quantgrouped_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 515e6eb0274..d2b95d32633 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_bf8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index eaf10f057c8..a8c13c1b3dd 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_fp8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); #ifndef CK_GFX950_SUPPORT @@ -55,4 +54,5 @@ void bquant_quantgrouped_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index c91867534f6..6576b22c038 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_fp8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp index 7166a5647ea..e0c112e3b7c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_bf8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = @@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp index 85599864db4..3ffcfdac694 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp index 87cb4c9d100..de7e290eeba 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_fp8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "non-preshufflequant", "1x1x128"})] = @@ -50,4 +49,5 @@ void bquant_quantgrouped_preshuffleb_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp index 0cb16441a9b..d36c20e700c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_BQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp index 640757a9562..12e23ba7222 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"bf8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = @@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp index 575a43afd89..cb8beee11de 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp index 9e40fbaa875..edfa1443410 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"fp8", "bquant", "preshuffleb", "preshufflequant", "1x1x128"})] = @@ -47,4 +46,5 @@ void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp index 2552a1d1348..c83dc0a396a 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp @@ -17,9 +17,8 @@ using GemmConfig = GemmConfigPreshuffleB_PreshuffleBQuant_Prefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp index edb28236aff..2aa54fa72b2 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"bf8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = @@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_bf8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp index 59da63447ec..2ace775216d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp index 29c88001e83..aba9a146caf 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); lut[hash_multiple_strings({"fp8", "bquant", "non-preshuffleb", "preshufflequant", "1x1x128"})] = @@ -52,4 +51,5 @@ void bquant_quantgrouped_preshufflequant_fp8_instance_factory( using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp index f4871325575..e0e21cef139 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp @@ -12,9 +12,8 @@ using GemmConfig = GemmConfigPreshuffleBQuantPrefill; QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); -void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 8de58b0a309..1fbe4d7b47b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -95,51 +95,6 @@ auto gen_lut_key(const ck_tile::ArgParser& arg_parser) return hash_multiple_strings(params); } -void abquant_quantgrouped_instance_factory( - std::unordered_map>& lut); -void aquant_quantgrouped_instance_factory( - std::unordered_map>& lut); -void aquant_quantgrouped_preshufflequant_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_bf16fp4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory( - std::unordered_map>& lut); -void bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory( - std::unordered_map>& lut); -void quant_rowcol_instance_factory( - std::unordered_map>& lut); -void quant_tensor_instance_factory( - std::unordered_map>& lut); - int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -153,29 +108,8 @@ int main(int argc, char* argv[]) std::cout << "Device ID: " << device_id << std::endl; ck_tile::hip_check_error(hipSetDevice(device_id)); - std::unordered_map> lut; - abquant_quantgrouped_instance_factory(lut); - aquant_quantgrouped_instance_factory(lut); - aquant_quantgrouped_preshufflequant_instance_factory(lut); - bquant_quantgrouped_fp8_instance_factory(lut); - bquant_quantgrouped_bf8_instance_factory(lut); - bquant_quantgrouped_fp8i4_instance_factory(lut); - bquant_quantgrouped_bf8i4_instance_factory(lut); - bquant_quantgrouped_bf16fp4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_fp8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_bf8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_bf8i4_instance_factory(lut); - bquant_quantgrouped_preshufflequant_fp8_instance_factory(lut); - bquant_quantgrouped_preshufflequant_bf8_instance_factory(lut); - bquant_quantgrouped_preshufflequant_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshufflequant_bf8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_fp8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_bf8_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4_instance_factory(lut); - bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4_instance_factory(lut); - quant_rowcol_instance_factory(lut); - quant_tensor_instance_factory(lut); + auto& lut = get_kernel_lut(); + std::cout << "Available kernels: " << lut.size() << std::endl; auto key = gen_lut_key(arg_parser); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp index 19c02b7ae21..d450a36f84e 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_rowcol.cpp @@ -6,9 +6,8 @@ template using GemmConfig = GemmConfigQuantDecode; -void quant_rowcol_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); // NOTE: QuantGroupSize is a place holder. rowcol pipeline does not use QuantGroupSize using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "rowcol"})] = [](const ck_tile::ArgParser& arg_parser) { @@ -27,4 +26,5 @@ void quant_rowcol_instance_factory( QuantGroupSize, ck_tile::QuantType::RowColQuant>(arg_parser); }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp index 0deb3d890bc..71b193d8f42 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_tensor.cpp @@ -6,9 +6,8 @@ template using GemmConfig = GemmConfigQuantDecode; -void quant_tensor_instance_factory( - std::unordered_map>& lut) -{ +static auto _ = []() { + auto& lut = get_kernel_lut(); // NOTE: QuantGroupSize is a place holder. tensor pipeline does not use QuantGroupSize using QuantGroupSize = ck_tile::QuantGroupShape>; lut[hash_multiple_strings({"fp8", "tensor"})] = [](const ck_tile::ArgParser& arg_parser) { @@ -27,4 +26,5 @@ void quant_tensor_instance_factory( QuantGroupSize, ck_tile::QuantType::TensorQuant>(arg_parser); }; -} + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a95ca4862cf..d62520f1c72 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -11,6 +11,14 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/gemm_quant.hpp" +inline auto& get_kernel_lut() +{ + // In an inline function, function-local static objects in all function definitions are shared + // across all translation units. + static std::unordered_map> lut; + return lut; +} + inline size_t hash_multiple_strings(const std::vector& inputs) { std::hash hasher; From 662bacdc82dda3fe6293ae7b89caca459e9da434 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Wed, 21 Jan 2026 03:31:58 +0000 Subject: [PATCH 2/6] Gemm quant improvement --- .../run_gemm_quant_example.inc | 16 +++++++--------- include/ck_tile/core/arch/arch.hpp | 2 +- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 8 ++++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 912527c929a..d57a7f48718 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -80,10 +80,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ck_tile::BaseGemmPipelineAgBgCrMem, ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2>>>; - const ck_tile::index_t K_split = - (args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile); + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { @@ -553,8 +552,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout))); } - std::random_device rd; - std::mt19937 gen(rd()); + std::mt19937 gen(42); std::uniform_int_distribution fill_seed(0, 500); if(init_method == 0) @@ -630,7 +628,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if(init_method == 1) { std::cout << "Monotonic initialization is not supported." << std::endl; - return 0; + return -1; } else if(init_method == 2) { @@ -900,10 +898,10 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if(arg_parser.get_int("v") == 2) { std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; - return false; + return -1; } - return pass; + return pass ? 0 : -1; } // Usage of Two-Matrix Quantization (AB-Quant) template LdsBanksWidth) ? 1 : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType)); @@ -250,7 +250,7 @@ struct UniversalGemmBasePolicy constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = max(MinLdsLayer, - get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); @@ -357,7 +357,7 @@ struct UniversalGemmBasePolicy constexpr auto K0PerThreadRead = BK0 / KThreadRead; // check if we exceed all LDS banks - constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_words_per_128b(); + constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b(); constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth) ? 1 : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType)); @@ -450,7 +450,7 @@ struct UniversalGemmBasePolicy constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto NLdsLayer = max(MinLdsLayer, - get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize); + get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize); constexpr index_t NBanks = get_n_lds_banks(); static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count"); From 8a5aea40c6451f2968fe2e3e01a3e5447d063ed6 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Thu, 22 Jan 2026 08:21:11 +0000 Subject: [PATCH 3/6] Change preshuffle --- include/ck_tile/host/tensor_shuffle_utils.hpp | 42 ++++++++++--------- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 10 ++--- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 18 ++++++-- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 7cd9889d78d..147b033ff98 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -69,7 +69,7 @@ auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) } template -auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) +auto shuffle_b(const ck_tile::HostTensor& t, GemmConfig) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; @@ -79,36 +79,40 @@ auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, - gemmConfig.N_Warp_Tile, - k_ / gemmConfig.K_Warp_Tile, + int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); } - else + else if(ck_tile::is_gfx11_supported()) { int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; - } - ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, - gemmConfig.N_Warp_Tile, - k_ / gemmConfig.K_Warp_Tile, + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / GemmConfig::K_Warp_Tile, divisor, - gemmConfig.K_Warp_Tile / divisor}); + GemmConfig::K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } + else + { + constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = + std::min(16 / static_cast(sizeof(T)), GemmConfig::K_Warp_Tile / KLane); + + ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, + GemmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); + } } template diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index fd94dfb6b3d..ce8eb94a1bb 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -693,13 +693,13 @@ struct QuantGemmKernel { if constexpr(PreshuffleB) { - index_t kFlatK = - GemmPipeline::flatKPerWarp * - (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; + constexpr auto warp_k = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + index_t kFlatKSplit = GemmPipeline::flatKPerWarp * (k_size / warp_k); + index_t kFlatK = GemmPipeline::flatKPerWarp * (kargs.K / warp_k); + index_t kFlatN = kargs.N * kargs.K / kFlatK; return make_naive_tensor_view( b_ptr, - make_tuple(kFlatN, kFlatK), + make_tuple(kFlatN, kFlatKSplit), make_tuple(kFlatK, 1), number{}, number<1>{}); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index 80e41cad458..c35c3d95320 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -52,11 +52,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; + using BDataType = typename Problem::BDataType; constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); + constexpr index_t KBPerLoad = + min(GetKBPerLoad(), 16 / static_cast(sizeof(BDataType))); #if defined(__gfx11__) constexpr index_t KRepeatInWave = 2; #else @@ -64,8 +66,8 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel #endif constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; - constexpr index_t KRepeat = 1; - static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + constexpr index_t KRepeat = GetKBPerLoad() / KBPerLoad; + static_assert(TileShape::flatKPerWarp == KRepeat * KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; @@ -98,13 +100,21 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel typename Problem::ADataType, typename Problem::BDataType>; + using BDataType = typename Problem::BDataType; + constexpr auto NumAccess = + 16 / sizeof(BDataType) * numeric_traits::PackedSize == 16 + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; using WarpGemm = WarpGemmDispatcher; + Problem::TransposeC, + false, + false, + NumAccess>; // TODO : Use a custom block policy for AsBrCr using BlockGemmPolicy = From 4e24ca16d2081947f82c43e706c4d1072350a42c Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Fri, 23 Jan 2026 01:54:07 +0000 Subject: [PATCH 4/6] Fix --- ...gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index c35c3d95320..ae2a601f8a1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -100,11 +100,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel typename Problem::ADataType, typename Problem::BDataType>; - using BDataType = typename Problem::BDataType; - constexpr auto NumAccess = - 16 / sizeof(BDataType) * numeric_traits::PackedSize == 16 - ? WGAttrNumAccessEnum::Double - : WGAttrNumAccessEnum::Single; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; + using BDataType = typename Problem::BDataType; + constexpr index_t KLaneBytes = + KLane / numeric_traits::PackedSize * sizeof(BDataType); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + using WarpGemm = WarpGemmDispatcher Date: Mon, 26 Jan 2026 06:44:25 +0000 Subject: [PATCH 5/6] Fix grouped gemm ut --- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 42 ++++++++++++------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 1ff95b157cb..0e3a8c2f4e7 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -151,6 +151,7 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; + using BDataType = typename Problem::BDataType; constexpr index_t kNPerBlock = TileShape::kN; constexpr index_t kKPerBlock = TileShape::kK; @@ -162,7 +163,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = GetKBPerLoad(); + constexpr index_t KBPerLoad = + min(GetKBPerLoad(), 16 / static_cast(sizeof(BDataType))); #if defined(__gfx11__) constexpr index_t KRepeatInWave = 2; #else @@ -171,7 +173,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = KIterPerWarp; - static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + constexpr index_t KAccess = GetKBPerLoad() / KBPerLoad; + static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad, "wrong"); constexpr index_t NBPerLoad = 1; constexpr index_t NThdPerWave = 1; @@ -181,16 +184,16 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; return make_static_tile_distribution( tile_distribution_encoding< - sequence, // ? - tuple, // second direction - sequence>, // first direction + sequence, // ? + tuple, // second direction + sequence>, // wave in blk, // thd in wave // // tuple, sequence<0, 1, 2>>, // which direction - tuple, sequence<1, 2, 2>>, // which index + tuple, sequence<1, 2, 3>>, // which index // - sequence<1, 2, 1, 2>, - sequence<0, 0, 3, 3>>{}); + sequence<1, 2, 1, 2, 2>, + sequence<0, 0, 3, 1, 4>>{}); } template @@ -256,13 +259,22 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>; - using WarpGemm = WarpGemmDispatcher; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; + using BDataType = typename Problem::BDataType; + constexpr index_t KLaneBytes = + KLane / numeric_traits::PackedSize * sizeof(BDataType); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy Date: Tue, 27 Jan 2026 01:12:23 -0500 Subject: [PATCH 6/6] Fix --- .../pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 0e3a8c2f4e7..1784436f870 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -163,13 +163,13 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t WaveNum = BlockSize / WaveSize; - constexpr index_t KBPerLoad = - min(GetKBPerLoad(), 16 / static_cast(sizeof(BDataType))); #if defined(__gfx11__) constexpr index_t KRepeatInWave = 2; #else constexpr index_t KRepeatInWave = 1; #endif + constexpr index_t KBPerLoad = min( + GetKBPerLoad(), KRepeatInWave * 16 / static_cast(sizeof(BDataType))); constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim constexpr index_t KWavePerBlk = 1; constexpr index_t KRepeat = KIterPerWarp;