Implement device grouped gemm fixed nk multi abd for rdna4 #3619
Changes from 8 commits
b637af1
4b32353
ab6faa7
cf0ebfb
97db9d6
9a10230
06dc334
96d5a96
06bac70
15b26a3
a84d37b
55c6b7a
9e5e86e
4d00d42
ec6c772
    @@ -0,0 +1,204 @@
    // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
    // SPDX-License-Identifier: MIT

    #pragma once

    #include <iostream>
    #include <sstream>

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
    #include "ck/tensor_operation/gpu/device/device_base.hpp"
    #include "ck/library/utility/host_tensor.hpp"
    #include "ck/utility/functional4.hpp"

    #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

    namespace ck {
    namespace tensor_operation {
    namespace host {

    // this function is also defined in CK but because of the way we use it in
    // profile_gemm_multi_impl, it requires the arguments to not be const

    // this function is also defined in CK but because of the way we use it in
    // profile_gemm_multi_impl, it requires the arguments to not be const
    // NOTE:
    // This helper intentionally duplicates `concat_tuple_of_refs` from the core CK utilities,
    // but with a different const-correctness contract on its arguments:
    //
    //   - The CK version is defined to operate on (typically) const-qualified tuples of
    //     references; its parameters are more permissive and can accept `const Tuple<...>&`.
    //   - This host-side overload is deliberately restricted to *non-const* tuples of
    //     references: `ck::Tuple<X&...>&` and `ck::Tuple<Y&...>&`.
    //
    // In `profile_gemm_multi_impl`, we need to concatenate tuples that contain non-const
    // references to tensors/buffers so that:
    //   * The resulting concatenated tuple preserves non-const reference semantics, allowing
    //     the profiled kernels and host-side utilities to modify the referenced objects, and
    //   * Overload resolution / SFINAE continues to select APIs that require non-const
    //     references (these would reject a const-qualified tuple produced by the CK version).
    //
    // If this function were replaced by the CK version, the arguments in
    // `profile_gemm_multi_impl` could become (or be treated as) const, which would either:
    //   - Prevent intended mutation of the underlying tensors, or
    //   - Cause subtle compilation or behavior differences due to const propagation.
    //
    // For that reason, this duplicate, non-const overload must remain local to the host-side
    // GEMM multi reference implementation and should not be "simplified" by switching to the
    // CK variant without carefully revisiting `profile_gemm_multi_impl` and its call sites.
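The non-const contract described in the comment above can be illustrated with a small standalone sketch. This is not the CK implementation: the helper's body is not shown in this diff, so the version below uses `std::tuple` and `std::tuple_cat` as stand-ins for `ck::Tuple`, and `writes_reach_originals` is a hypothetical name introduced only for the demonstration. It shows, under those assumptions, that concatenating two non-const tuples of references yields a tuple whose elements are still mutable references to the original objects.

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>

// Assumption: std::tuple stands in for ck::Tuple so the sketch builds without CK.
// The parameters are non-const lvalue references to tuples of references,
// mirroring the `Tuple<X&...>&` / `Tuple<Y&...>&` shape named in the comment.
template <typename... X, typename... Y>
auto concat_tuple_of_refs(std::tuple<X&...>& tx, std::tuple<Y&...>& ty)
{
    // std::tuple_cat keeps each element type as-is, so X& and Y& stay references.
    return std::tuple_cat(tx, ty);
}

// Hypothetical demo: writes through the concatenated tuple and reports
// whether those writes reached the original objects.
inline bool writes_reach_originals()
{
    int a    = 1;
    float b  = 2.0f;
    double c = 3.0;

    std::tuple<int&, float&> ab(a, b);
    std::tuple<double&> only_c(c);

    auto all = concat_tuple_of_refs(ab, only_c);

    // The result is a tuple of non-const references, not of copies.
    static_assert(
        std::is_same<decltype(all), std::tuple<int&, float&, double&>>::value,
        "concatenation should preserve non-const reference element types");

    std::get<0>(all) = 10;   // modifies a
    std::get<2>(all) = 30.0; // modifies c

    return a == 10 && b == 2.0f && c == 30.0;
}
```

Because the parameters are `std::tuple<...>&` rather than `const std::tuple<...>&`, passing a const-qualified tuple to this sketch does not compile, which is the restriction the comment attributes to the host-side helper.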
Might still be a good idea to place this in the util header where the other const version is also stored.
done