Skip to content
Merged
8 changes: 8 additions & 0 deletions example/59_grouped_gemm_multi_ABD/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm

add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)

# WMMA variants of the grouped GEMM multi-ABD examples, collected under their
# own umbrella target (mirrors the XDL umbrella target used above).
add_custom_target(example_grouped_gemm_wmma_multi_abd)

# fp16 fixed-NK bias example.
add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16 grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_fp16)

# bf16 activation / int8 weight fixed-NK bias example.
add_example_executable(example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8.cpp)
add_example_dependencies(example_grouped_gemm_wmma_multi_abd example_grouped_gemm_multi_abd_wmma_fixed_nk_bias_bf16_i8)
Comment thread
ErwinTerpstra marked this conversation as resolved.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,11 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK

static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_wmma_supported<ComputeType, ComputeType, MPerXDL, NPerXDL>())
{
return false;
}

// Split-K autodeduction is not supported
if(arg.k_batch_ < 1)
{
Expand Down Expand Up @@ -720,6 +725,26 @@ struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
}
}

for(index_t i = 0; i < arg.group_count_; i++)
{
if(get_warp_size() == 64)
{
if(GridwiseGemm64::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
true)
{
supported = false;
}
}
else
{
if(GridwiseGemm32::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) !=
true)
{
supported = false;
}
}
}

return supported;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
using typename Base::DsGridPointer;
using AsDataType_ = AsDataType;
using BsDataType_ = BsDataType;
using EDataType_ = EDataType;

struct Problem
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>

#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/utility/functional4.hpp"

#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add documentation explaining why this duplicate function definition exists and how it differs from the CK version. The comment on line 20-21 is insufficient - it should explain the const-correctness difference and the implications for usage in profile_gemm_multi_impl.

Suggested change
// this function is also defined in CK but because of the way we use it in
// profile_gemm_multi_impl, it requires the arguments to not be const
// NOTE:
// This helper intentionally duplicates `concat_tuple_of_refs` from the core CK utilities,
// but with a different const-correctness contract on its arguments:
//
// - The CK version is defined to operate on (typically) const-qualified tuples of
// references; its parameters are more permissive and can accept `const Tuple<...>&`.
// - This host-side overload is deliberately restricted to *non-const* tuples of
// references: `ck::Tuple<X&...>&` and `ck::Tuple<Y&...>&`.
//
// In `profile_gemm_multi_impl`, we need to concatenate tuples that contain non-const
// references to tensors/buffers so that:
// * The resulting concatenated tuple preserves non-const reference semantics, allowing
// the profiled kernels and host-side utilities to modify the referenced objects, and
// * Overload resolution / SFINAE continues to select APIs that require non-const
// references (these would reject a const-qualified tuple produced by the CK version).
//
// If this function were replaced by the CK version, the arguments in
// `profile_gemm_multi_impl` could become (or be treated as) const, which would either:
// - Prevent intended mutation of the underlying tensors, or
// - Cause subtle compilation or behavior differences due to const propagation.
//
// For that reason, this duplicate, non-const overload must remain local to the host-side
// GEMM multi reference implementation and should not be "simplified" by switching to the
// CK variant without carefully revisiting `profile_gemm_multi_impl` and its call sites.

Copilot uses AI. Check for mistakes.
template <typename... X, typename... Y>
auto concat_tuple_of_refs(ck::Tuple<X&...>& tx, ck::Tuple<Y&...>& ty)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might still be a good idea to place this in the util header where the other const version is also stored.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
return ck::unpack2(
[&](auto&&... zs) { return ck::Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}

// Host (CPU) reference implementation of a GEMM with multiple A, B and D
// tensors:
//
//   a_compute[m,k] = a_element_op(As[0][m,k], ..., As[NumA-1][m,k])
//   b_compute[k,n] = b_element_op(Bs[0][k,n], ..., Bs[NumB-1][k,n])
//   acc[m,n]       = sum_k a_compute[m,k] * b_compute[k,n]
//   E[m,n]         = cde_element_op(acc[m,n], Ds[0][m,n], ..., Ds[NumD-1][m,n])
//
// Intended for result verification, not performance; Run() always reports a
// timing of 0.
template <typename AsTensorTuple,
          typename BsTensorTuple,
          typename DsTensorTuple,
          typename EDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AComputeType,
          typename BComputeType>
struct ReferenceGemmMultiABD : public device::BaseOperator
{
    // Argument: non-owning references to all input/output tensors plus copies
    // of the element-wise operators. The referenced tensors must outlive the
    // Argument.
    struct Argument : public device::BaseArgument
    {
        Argument(const AsTensorTuple& as_m_k,
                 const BsTensorTuple& bs_k_n,
                 const DsTensorTuple& ds_m_n,
                 Tensor<EDataType>& e_m_n,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : as_m_k_{as_m_k},
              bs_k_n_{bs_k_n},
              ds_m_n_{ds_m_n},
              e_m_n_{e_m_n},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
        }

        const AsTensorTuple& as_m_k_;
        const BsTensorTuple& bs_k_n_;
        const DsTensorTuple& ds_m_n_;
        Tensor<EDataType>& e_m_n_; // output, written by Invoker::Run

        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemmMultiABD::Argument;

        // Computes E = cde(A-fused x B-fused, Ds...) on the host.
        // Returns 0 (no timing is measured for the reference path).
        float Run(const Argument& arg)
        {
            static constexpr index_t NumATensor = AsTensorTuple::Size();
            static constexpr index_t NumBTensor = BsTensorTuple::Size();
            static constexpr index_t NumDTensor = DsTensorTuple::Size();

            // Problem sizes are taken from the first A/B tensors; assumes all
            // As share the same MxK lengths and all Bs the same KxN lengths —
            // TODO confirm this invariant is enforced by the callers.
            const int M = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[0];
            const int K = arg.as_m_k_[Number<0>{}].mDesc.GetLengths()[1];
            const int N = arg.bs_k_n_[Number<0>{}].mDesc.GetLengths()[1];

            // Fuse the multiple A tensors element-wise into a single compute
            // tensor: a_m_k(m,k) = a_element_op(As...[m,k]).
            Tensor<AComputeType> a_m_k({M, K});
            for(int m = 0; m < M; ++m)
            {
                for(int k = 0; k < K; ++k)
                {
                    // result
                    auto data_refs1 = ck::tie(a_m_k(m, k));
                    // inputs
                    auto data_refs2 = generate_tie(
                        [&](auto i) -> auto& { return arg.as_m_k_[Number<i>{}](m, k); },
                        Number<NumATensor>{});
                    auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
                    unpack(arg.a_element_op_, data_refs);
                }
            }

            // Fuse the multiple B tensors element-wise into a single compute
            // tensor: b_k_n(k,n) = b_element_op(Bs...[k,n]).
            Tensor<BComputeType> b_k_n({K, N});
            for(int k = 0; k < K; ++k)
            {
                for(int n = 0; n < N; ++n)
                {
                    // result
                    auto data_refs1 = ck::tie(b_k_n(k, n));
                    // inputs
                    auto data_refs2 = generate_tie(
                        [&](auto i) -> auto& { return arg.bs_k_n_[Number<i>{}](k, n); },
                        Number<NumBTensor>{});
                    auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
                    unpack(arg.b_element_op_, data_refs);
                }
            }

            // Plain GEMM on the fused compute tensors; the element-wise work
            // has already been applied above, so all ops are pass-through.
            using PassThrough = ck::tensor_operation::element_wise::PassThrough;
            Tensor<AccDataType> c_m_n({M, N});

            using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<AComputeType,
                                                                                    BComputeType,
                                                                                    AccDataType,
                                                                                    AccDataType,
                                                                                    PassThrough,
                                                                                    PassThrough,
                                                                                    PassThrough>;
            auto ref_gemm    = ReferenceGemmInstance{};
            auto ref_invoker = ref_gemm.MakeInvoker();

            auto ref_argument = ref_gemm.MakeArgument(
                a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});

            ref_invoker.Run(ref_argument);

            // Combine the accumulator with the optional D tensors into E:
            // E(m,n) = cde_element_op(c_m_n(m,n), Ds...[m,n]).
            for(int m = 0; m < M; ++m)
            {
                for(int n = 0; n < N; ++n)
                {
                    // compulsory
                    auto data_refs1 = ck::tie(arg.e_m_n_(m, n), c_m_n(m, n));
                    // optional (if multiple Ds)
                    auto data_refs2 = generate_tie(
                        [&](auto i) -> auto& { return arg.ds_m_n_[Number<i>{}](m, n); },
                        Number<NumDTensor>{});
                    auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2);
                    unpack(arg.cde_element_op_, data_refs);
                }
            }

            return 0;
        }

        // Polymorphic entry point required by device::BaseInvoker.
        // Throws std::runtime_error if p_arg is not a ReferenceGemmMultiABD
        // Argument (previously the dynamic_cast result was dereferenced
        // without a null check, which is undefined behavior on mismatch).
        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /* stream_config */ = StreamConfig{}) override
        {
            const auto* p_typed_arg = dynamic_cast<const Argument*>(p_arg);

            if(p_typed_arg == nullptr)
            {
                throw std::runtime_error(
                    "ReferenceGemmMultiABD::Invoker::Run: invalid argument type");
            }

            return Run(*p_typed_arg);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    // Host reference path: every representable argument is supported.
    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const AsTensorTuple& as_m_k,
                             const BsTensorTuple& bs_k_n,
                             const DsTensorTuple& ds_m_n,
                             Tensor<EDataType>& e_m_n,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{as_m_k, bs_k_n, ds_m_n, e_m_n, a_element_op, b_element_op, cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "ReferenceGemmMultiABD"
            << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
Loading