diff --git a/experimental/builder/include/ck_tile/builder/reflect/README.md b/experimental/builder/include/ck_tile/builder/reflect/README.md index 8bb9c89c800..43192b7c487 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/README.md +++ b/experimental/builder/include/ck_tile/builder/reflect/README.md @@ -9,6 +9,7 @@ See the [main builder documentation](../README.md) for an overview. The reflection system works by extracting properties from a convolution kernel *type* and formatting them into a string. This is useful for debugging, performance tuning, and generating documentation. 1. **Trait Extraction**: The `ConvTraits` template (in `conv_traits.hpp`) is specialized for each kernel instance. It extracts low-level details like tile sizes, data layouts, and pipeline versions from the kernel's type definition. +This template is common to the XDL and WMMA, forward and backward-weight kernels. `std::optional` is used for parameters that are only present in some kernels. 2. **Description Generation**: The `describe()` function (in `conv_description.hpp`) uses `ConvTraits` to populate a `ConvDescription` (`Description`) object. 
@@ -48,6 +49,15 @@ The reflection system (`ckr::describe`) currently supports the following convolu - **Standard XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle`) - **Large Tensor XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor`) - **V3 XDL Forward Convolution** (`DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3`) +- **V3 WMMA Forward Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`) +- **XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`) +- **V3 XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffleV3`) +- **XDL Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle`) +- **Two Stage XDL Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle`) +- **V3 Two Stage WMMA Backward Weight Convolution** (`DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3`) +- **Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffle`) +- **V3 Wmma Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Wmma_CShuffleV3`) +- **V3 Wmma Multiple D Backward Weight Convolution** (`DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3`) These variants all share similar template parameter structures and are compatible with the current `ConvTraits` implementation. @@ -59,15 +69,6 @@ The following instance types are **not yet supported** by the reflection system: - Uses different internal structure with parameters like `K0PerBlock`, `K1`, `M1PerThread`, etc. 
- Missing standard members like `kKPerBlock`, `kMPerXDL`, `kAK1` -- **WMMA Variants** (`DeviceGroupedConvFwdMultipleD_Wmma_CShuffle`) - - Uses WMMA-specific parameters like `MPerWmma`, `NPerWmma`, `MRepeat`, `NRepeat` - - Different tile transfer structure incompatible with current `ConvTraits` - -- **Backward Weight Convolution** (`DeviceGroupedConvBwdWeight_Xdl_CShuffle`) - - Uses different layout naming: `InLayout`, `WeiLayout`, `OutLayout` instead of `ALayout`, `BLayout`, `ELayout` - - Different specialization type: `ConvBackwardWeightSpecialization` vs `ConvForwardSpecialization` - - Missing several members expected by forward convolution traits - ### Future Work To support these additional instance types, the reflection system would need: diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp index 359b12c4a30..27e7dfb3629 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_describe.hpp @@ -29,30 +29,7 @@ conv::ConvDescription describe() const auto traits = conv::instance_to_conv_traits(); return conv::ConvDescription( - conv::ConvSignatureInfo{ - .spatial_dim = traits.spatial_dim, - .direction = traits.direction, - .input_layout = traits.layout[0], - .weight_layout = traits.layout[1], - .output_layout = traits.layout[2], - .data_type = traits.data_type, - .input_element_op = traits.input_element_op, - .weight_element_op = traits.weight_element_op, - .output_element_op = traits.output_element_op, - }, - conv::GemmAlgorithmInfo{ - .thread_block_size = traits.thread_block_size, - .tile_dims = traits.tile_dims, - .warp_gemm = traits.warp_gemm, - .a_tile_transfer = traits.a_tile_transfer, - .b_tile_transfer = traits.b_tile_transfer, - .c_tile_transfer = traits.c_tile_transfer, - .pipeline_version = traits.pipeline_version, - .pipeline_scheduler = 
traits.pipeline_scheduler, - .conv_specialization = traits.conv_specialization, - .padding = traits.gemm_padding, - }, - []() { return reflect::instance_string(); }); + traits, []() { return reflect::instance_string(); }); } } // namespace ck_tile::reflect diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index a7b6c60a73e..5c09e4b735b 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -29,44 +29,12 @@ #include #include #include +#include namespace ck_tile::reflect { namespace conv { -/// @brief Signature information for a convolution operation -/// Contains high-level properties that define the convolution's interface, -/// including dimensionality, data layout, data types, and elementwise operations. -struct ConvSignatureInfo -{ - int spatial_dim; - builder::ConvDirection direction; - builder::TensorLayout input_layout; - builder::TensorLayout weight_layout; - builder::TensorLayout output_layout; - builder::DataType data_type; - builder::ElementwiseOperation input_element_op; - builder::ElementwiseOperation weight_element_op; - builder::ElementwiseOperation output_element_op; -}; - -/// @brief Algorithm configuration for a convolution kernel -/// Contains low-level implementation details including thread block configuration, -/// tile dimensions, memory access patterns, and pipeline settings. 
-struct GemmAlgorithmInfo -{ - int thread_block_size; - DataTileInfo tile_dims; - WarpGemmParams warp_gemm; - InputTileTransferInfo a_tile_transfer; - InputTileTransferInfo b_tile_transfer; - OutputTileTransferInfo c_tile_transfer; - builder::PipelineVersion pipeline_version; - builder::PipelineScheduler pipeline_scheduler; - builder::ConvSpecialization conv_specialization; - builder::GemmPadding padding; -}; - /// @brief Provides human-readable descriptions of convolution kernel instances /// Generates formatted text descriptions at various levels of detail for /// understanding and documenting convolution kernel configurations. @@ -74,16 +42,12 @@ class ConvDescription : public Description { public: /// @brief Constructor for ConvDescription - /// @param sig The signature information containing high-level convolution properties - /// @param algo The algorithm configuration containing low-level implementation details + /// @param traits The ConvTraits object containing all relevant signature and algorithm + /// information /// @param instance_string_getter A callable that returns a string representation of the /// instance - ConvDescription(ConvSignatureInfo sig, - GemmAlgorithmInfo algo, - std::function instance_string_getter) - : signature_(std::move(sig)), - algorithm_(std::move(algo)), - instance_string_getter_(std::move(instance_string_getter)) + ConvDescription(ConvTraits traits, std::function instance_string_getter) + : traits_(std::move(traits)), instance_string_getter_(std::move(instance_string_getter)) { } @@ -92,7 +56,7 @@ class ConvDescription : public Description std::string brief() const override { std::ostringstream oss; - oss << signature_.spatial_dim << "D " << signature_.direction << " convolution"; + oss << traits_.spatial_dim << "D " << traits_.direction << " convolution"; return oss.str(); } @@ -101,39 +65,42 @@ class ConvDescription : public Description std::string detailed() const override { TreeFormatter f; - f.writeLine(0, 
signature_.spatial_dim, "D ", signature_.direction, " Convolution Kernel"); + f.writeLine(0, traits_.spatial_dim, "D ", traits_.direction, " Convolution Kernel"); f.writeLine(1, "Signature"); - f.writeLine(2, "Tensor Type: ", signature_.data_type); - f.writeLine(2, "Input Layout: ", signature_.input_layout); - f.writeLine(2, "Weight Layout: ", signature_.weight_layout); - f.writeLine(2, "Output Layout: ", signature_.output_layout); - f.writeLine(2, "Input elementwise operation: ", signature_.input_element_op); - f.writeLine(2, "Weights elementwise operation: ", signature_.weight_element_op); - f.writeLast(2, "Output elementwise operation: ", signature_.output_element_op); + f.writeLine(2, "Tensor Type: ", traits_.data_type); + f.writeLine(2, "Input Layout: ", traits_.layout[0]); + f.writeLine(2, "Weight Layout: ", traits_.layout[1]); + f.writeLine(2, "Output Layout: ", traits_.layout[2]); + f.writeLine(2, "Input elementwise operation: ", traits_.input_element_op); + f.writeLine(2, "Weights elementwise operation: ", traits_.weight_element_op); + f.writeLast(2, "Output elementwise operation: ", traits_.output_element_op); f.writeLast(1, "Algorithm"); // Compute Block section - f.writeLine(2, "Thread block size: ", algorithm_.thread_block_size); + f.writeLine(2, "Thread block size: ", traits_.thread_block_size); f.writeLine(2, "Data tile size: ", - algorithm_.tile_dims.m, + traits_.tile_dims.m, "×", - algorithm_.tile_dims.n, + traits_.tile_dims.n, "×", - algorithm_.tile_dims.k); - f.writeLine(2, "Gemm padding: ", algorithm_.padding); - f.writeLine(2, "Convolution specialization: ", algorithm_.conv_specialization); + traits_.tile_dims.k); + if(traits_.gemm_padding) + f.writeLine( + 2, "Gemm padding: ", traits_.gemm_padding.value_or(builder::GemmPadding::DEFAULT)); + else + f.writeLine(2, "Struct does not contain optional gemm_padding argument"); + f.writeLine(2, "Convolution specialization: ", traits_.conv_specialization); // Pipeline section - f.writeLine(2, "Pipeline 
version: ", algorithm_.pipeline_version); - f.writeLine(2, "Pipeline scheduler: ", algorithm_.pipeline_scheduler); + f.writeLine(2, "Pipeline version: ", traits_.pipeline_version); + f.writeLine(2, "Pipeline scheduler: ", traits_.pipeline_scheduler); f.writeLine(2, "Warp Gemm parameters: "); - f.writeLine( - 3, "subtile size: ", algorithm_.warp_gemm.gemm_m, "×", algorithm_.warp_gemm.gemm_n); + f.writeLine(3, "subtile size: ", traits_.warp_gemm.gemm_m, "×", traits_.warp_gemm.gemm_n); f.writeLast(3, "Number of warp gemm iterations: ", - algorithm_.warp_gemm.m_iter, + traits_.warp_gemm.m_iter, "×", - algorithm_.warp_gemm.n_iter); + traits_.warp_gemm.n_iter); // Memory Access section f.writeLast(2, "Memory access:"); @@ -141,99 +108,126 @@ class ConvDescription : public Description f.writeLine(3, "A Tile transfer: "); f.writeLine(4, "Tile dimensions: ", - algorithm_.a_tile_transfer.tile_dimensions.k0, + traits_.a_tile_transfer.tile_dimensions.k0, "×", - algorithm_.a_tile_transfer.tile_dimensions.m_or_n, + traits_.a_tile_transfer.tile_dimensions.m_or_n, "×", - algorithm_.a_tile_transfer.tile_dimensions.k1, + traits_.a_tile_transfer.tile_dimensions.k1, "×"); - f.writeLine(4, - "The innermost K subdimension size: ", - algorithm_.a_tile_transfer.transfer_params.k1); + f.writeLine( + 4, "The innermost K subdimension size: ", traits_.a_tile_transfer.transfer_params.k1); f.writeLine(4, "Spatial thread distribution over the data tile: ", - algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[0], + traits_.a_tile_transfer.transfer_params.thread_cluster_order[0], "×", - algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[1], + traits_.a_tile_transfer.transfer_params.thread_cluster_order[1], "×", - algorithm_.a_tile_transfer.transfer_params.thread_cluster_order[2]); + traits_.a_tile_transfer.transfer_params.thread_cluster_order[2]); f.writeLine(4, "The order of accessing data tile axes: ", - algorithm_.a_tile_transfer.transfer_params.src_access_order[0], + 
traits_.a_tile_transfer.transfer_params.src_access_order[0], "×", - algorithm_.a_tile_transfer.transfer_params.src_access_order[1], + traits_.a_tile_transfer.transfer_params.src_access_order[1], "×", - algorithm_.a_tile_transfer.transfer_params.src_access_order[2]); + traits_.a_tile_transfer.transfer_params.src_access_order[2]); f.writeLine(4, "Vectorized memory access axis index (with contiguous memory): ", - algorithm_.a_tile_transfer.transfer_params.src_vector_dim); + traits_.a_tile_transfer.transfer_params.src_vector_dim); f.writeLine(4, "Vector access (GMEM read) instruction size: ", - algorithm_.a_tile_transfer.transfer_params.src_scalar_per_vector); + traits_.a_tile_transfer.transfer_params.src_scalar_per_vector); f.writeLine(4, "Vector access (LDS write) instruction size: ", - algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); f.writeLast(4, "LDS data layout padding (to prevent bank conflicts): ", - algorithm_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + traits_.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1); f.writeLine(3, "B Tile transfer: "); f.writeLine(4, "Tile dimensions: ", - algorithm_.b_tile_transfer.tile_dimensions.k0, + traits_.b_tile_transfer.tile_dimensions.k0, "×", - algorithm_.b_tile_transfer.tile_dimensions.m_or_n, + traits_.b_tile_transfer.tile_dimensions.m_or_n, "×", - algorithm_.b_tile_transfer.tile_dimensions.k1, + traits_.b_tile_transfer.tile_dimensions.k1, "×"); - f.writeLine(4, - "The innermost K subdimension size: ", - algorithm_.b_tile_transfer.transfer_params.k1); + f.writeLine( + 4, "The innermost K subdimension size: ", traits_.b_tile_transfer.transfer_params.k1); f.writeLine(4, "Spatial thread distribution over the data tile: ", - algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[0], + traits_.b_tile_transfer.transfer_params.thread_cluster_order[0], "×", - 
algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[1], + traits_.b_tile_transfer.transfer_params.thread_cluster_order[1], "×", - algorithm_.b_tile_transfer.transfer_params.thread_cluster_order[2]); + traits_.b_tile_transfer.transfer_params.thread_cluster_order[2]); f.writeLine(4, "The order of accessing data tile axes: ", - algorithm_.b_tile_transfer.transfer_params.src_access_order[0], + traits_.b_tile_transfer.transfer_params.src_access_order[0], "×", - algorithm_.b_tile_transfer.transfer_params.src_access_order[1], + traits_.b_tile_transfer.transfer_params.src_access_order[1], "×", - algorithm_.b_tile_transfer.transfer_params.src_access_order[2]); + traits_.b_tile_transfer.transfer_params.src_access_order[2]); f.writeLine(4, "Vectorized memory access axis index (with contiguous memory): ", - algorithm_.b_tile_transfer.transfer_params.src_vector_dim); + traits_.b_tile_transfer.transfer_params.src_vector_dim); f.writeLine(4, "Vector access (GMEM read) instruction size: ", - algorithm_.b_tile_transfer.transfer_params.src_scalar_per_vector); + traits_.b_tile_transfer.transfer_params.src_scalar_per_vector); f.writeLine(4, "Vector access (LDS write) instruction size: ", - algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); f.writeLast(4, "LDS data layout padding (to prevent bank conflicts): ", - algorithm_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); + traits_.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1); f.writeLast(3, "C Tile transfer: "); f.writeLine(4, "Data shuffle (number of gemm instructions per iteration): ", - algorithm_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, + traits_.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, "×", - algorithm_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); + traits_.c_tile_transfer.shuffle_params.n_gemms_per_shuffle); f.writeLine(4, "Spatial thread distribution used to store data: ", - 
algorithm_.c_tile_transfer.thread_cluster_dims[0], + traits_.c_tile_transfer.thread_cluster_dims[0], "×", - algorithm_.c_tile_transfer.thread_cluster_dims[1], + traits_.c_tile_transfer.thread_cluster_dims[1], "×", - algorithm_.c_tile_transfer.thread_cluster_dims[2], + traits_.c_tile_transfer.thread_cluster_dims[2], "×", - algorithm_.c_tile_transfer.thread_cluster_dims[3]); - f.writeLast(4, + traits_.c_tile_transfer.thread_cluster_dims[3]); + f.writeLine(4, "Vector access (GMEM write) instruction size: ", - algorithm_.c_tile_transfer.scalar_per_vector); + traits_.c_tile_transfer.scalar_per_vector); + if(traits_.num_gemm_k_prefetch_stage) + f.writeLine( + 2, "Num gemm k prefetch stage: ", traits_.num_gemm_k_prefetch_stage.value_or(0)); + else + f.writeLine(2, + "Struct does not contain optional " + "num_gemm_k_prefetch_stage parameter"); + + if(traits_.max_transpose_transfer_src_scalar_per_vector) + f.writeLine(2, + "Max Transpose transfer scr scalar per vector: ", + traits_.max_transpose_transfer_src_scalar_per_vector.value_or(0)); + else + f.writeLine(2, + "Struct does not contain optional " + "max_transpose_transfer_src_scalar_per_vector parameter"); + if(traits_.max_transpose_dst_scalar_per_vector) + f.writeLine(2, + "Max Transpose dst scalar per vector: ", + traits_.max_transpose_dst_scalar_per_vector.value_or(0)); + else + f.writeLine( + 2, + "Struct does not contain optional max_transpose_dst_scalar_per_vector parameter"); + if(traits_.num_groups_to_merge) + f.writeLast(2, "Num groups to merge: ", traits_.num_groups_to_merge.value_or(0)); + else + f.writeLast(2, "Struct does not contain optional num_groups_to_merge parameter"); + return f.getString(); } @@ -242,8 +236,7 @@ class ConvDescription : public Description std::string instance_string() const override { return instance_string_getter_(); } private: - ConvSignatureInfo signature_; - GemmAlgorithmInfo algorithm_; + ConvTraits traits_; std::function instance_string_getter_; }; diff --git 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 451a74be342..16a9c47f7eb 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -88,7 +88,7 @@ struct ConvTraits builder::ElementwiseOperation weight_element_op; builder::ElementwiseOperation output_element_op; - builder::GemmPadding gemm_padding; + std::optional gemm_padding = std::nullopt; builder::ConvSpecialization conv_specialization; // --- Algorithm Information --- @@ -102,8 +102,14 @@ struct ConvTraits OutputTileTransferInfo c_tile_transfer; + std::optional num_gemm_k_prefetch_stage = std::nullopt; + builder::PipelineVersion pipeline_version; builder::PipelineScheduler pipeline_scheduler; + + std::optional max_transpose_transfer_src_scalar_per_vector = std::nullopt; + std::optional max_transpose_dst_scalar_per_vector = std::nullopt; + std::optional num_groups_to_merge = std::nullopt; }; } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..f052a9701bc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 00000000000..2f7c68458f9 --- /dev/null +++ 
b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = 
InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..4f39b00b5cc --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + 
.b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kABK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100644 index 00000000000..5666233091e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdTwoStage_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + 
InstTraits::kTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kTransposeTransferDstScalarPerVector, + .num_groups_to_merge = InstTraits::kNumGroupsToMerge, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 00000000000..470a10d0317 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + 
conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp new file mode 100644 index 00000000000..13625aa1822 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -0,0 +1,50 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = 
conv_traits_data_tile(InstTraits::kKPerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kKPerBlock), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp new file mode 100644 index 00000000000..39fde332178 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -0,0 +1,56 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + .max_transpose_transfer_src_scalar_per_vector = + 
InstTraits::kMaxTransposeTransferSrcScalarPerVector, + .max_transpose_dst_scalar_per_vector = InstTraits::kMaxTransposeTransferDstScalarPerVector, + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100644 index 00000000000..de986455143 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = bwd_wei_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(InstTraits::kK0PerBlock), + .a_tile_transfer = + conv_traits_a_transfer_params(InstTraits::kK1, InstTraits::kK0PerBlock), + .b_tile_transfer = + conv_traits_b_transfer_params(InstTraits::kK1, 
InstTraits::kK0PerBlock), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = + {.shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector_NWaveNPerXdl}, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cdd238f36a1..2f5d84a4a80 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 28c43c342fc..2108c790548 100644 --- 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 00000000000..9413107df7e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/builder/reflect/conv_traits.hpp" +#include "ck_tile/builder/reflect/conv_traits_helpers.hpp" +#include "ck_tile/builder/reflect/instance_traits.hpp" +#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck_tile::reflect::conv { + +/// @brief Tag dispatch implementation for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag +template + requires HasInstanceTraits && + std::same_as::device_kernel_tag, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag> +constexpr ConvTraits instance_to_conv_traits() +{ + using InstTraits = InstanceTraits; + + return ConvTraits{ + .spatial_dim = InstTraits::kSpatialDim, + .direction = conv_direction(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), + .input_element_op = elementwise_op(), + .weight_element_op = elementwise_op(), + .output_element_op = elementwise_op(), + .gemm_padding = gemm_spec(), + .conv_specialization = conv_spec(), + .thread_block_size = InstTraits::kBlockSize, + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_a_transfer_params(InstTraits::kK1), + .b_tile_transfer = conv_traits_b_transfer_params(InstTraits::kK1), + .warp_gemm = conv_traits_wmma_warp_gemm_params(), + .c_tile_transfer = conv_traits_wmma_c_tile_transfer(), + .num_gemm_k_prefetch_stage = InstTraits::kNumGemmKPrefetchStage, + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), + }; +} + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index c4bed850ebc..0cce3bf5130 100644 --- 
a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -24,8 +24,8 @@ constexpr ConvTraits instance_to_conv_traits() return ConvTraits{ .spatial_dim = InstTraits::kSpatialDim, .direction = conv_direction(), - .layout = conv_layout(), - .data_type = conv_data_type(), + .layout = fwd_conv_layout(), + .data_type = conv_data_type(), .input_element_op = elementwise_op(), .weight_element_op = elementwise_op(), .output_element_op = elementwise_op(), diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 46c196e95ad..4baf2423ee1 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -80,6 +80,22 @@ namespace ck_tile::reflect::conv { // SECTION 1: ENUM CONVERSIONS // ============================================================================ +// Forward convolution layout concept - checks for A/B/E layout types +template +concept HasFwdConvLayouts = requires { + typename T::ALayout; + typename T::BLayout; + typename T::ELayout; +}; + +// Backwards weight layout concept - checks for In, wei and out layouts +template +concept HasBwdWeiLayouts = requires { + typename T::InLayout; + typename T::WeiLayout; + typename T::OutLayout; +}; + /// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. /// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. /// @return The corresponding builder::PipelineVersion enum value. 
@@ -322,12 +338,25 @@ constexpr builder::ConvSpecialization conv_spec() // Tensor Layouts // ---------------------------------------------------------------------------- +// Helper variable template to check if CK layout enums match +template +inline constexpr bool layouts_are = + std::is_same_v && std::is_same_v && std::is_same_v; + /// @brief Helper function to report unsupported layout combinations with a clear error message. /// @details This consteval function uses throw (not static_assert) to ensure the error is not /// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. +/// @details This consteval function is designed to fail at compile time with a descriptive +/// error message when an unsupported layout combination is encountered. template [[noreturn]] consteval void report_unsupported_layout_error() { + // This will produce a compile-time error with the exception message throw "Unsupported convolution layout combination detected!\n" "The combination of ALayout, BLayout, and ELayout template parameters\n" "is not recognized for the given spatial dimension.\n" @@ -335,111 +364,99 @@ template "Check the conv_layout() function for the list of supported layout combinations."; } -/// @brief Derives the grouped convolution layout from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return An std::array containing the layouts for: -/// - [0] Input tensor layout -/// - [1] Weight tensor layout -/// - [2] Output tensor layout -/// @details This function examines the Instance's ALayout, BLayout, and ELayout types -/// along with the spatial dimension to determine the appropriate layout configuration. -/// -/// Supported layout combinations vary by spatial dimension (1D, 2D, 3D convolutions). -/// Common patterns include GNHWC (grouped, batch, spatial, channels) and variants. 
-/// -/// @note Compilation will fail with a clear error message if the layout combination -/// is not supported for the given spatial dimension. -/// -/// TODO: If we don't check for supported layouts, this function can be simplified. -template -constexpr std::array conv_layout() +template +constexpr auto conv_layout() { - using InstTraits = InstanceTraits; - using A = typename InstTraits::ALayout; - using B = typename InstTraits::BLayout; - using E = typename InstTraits::ELayout; - namespace ctl = ck::tensor_layout::convolution; - using enum builder::TensorLayout; - - // Helper to check if layouts match expected types - constexpr auto layouts_match = []() { - return std::is_same_v && std::is_same_v && std::is_same_v; - }; - // Helper to construct layout array - constexpr auto make_layouts = [](auto in, auto weight, auto out) { - return std::array{in, weight, out}; - }; + // Helper lambda to construct layout array + auto layouts = [](auto... Ls) { return std::array{Ls...}; }; - constexpr int spatial_dim = InstTraits::kSpatialDim; + namespace ctl = ck::tensor_layout::convolution; + using enum builder::TensorLayout; - if constexpr(spatial_dim == 1) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNWC, GKXC, GNWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NWGC, GKXC, NWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKXC, NGKW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCW, GKCX, NGKW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNWC, GKXC, GNWK); // Unreachable - } - } - else if constexpr(spatial_dim == 2) + switch(kSpatialDim) { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match - .template operator()()) - return 
make_layouts(GNHWC, GKYXC, GNHWK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NHWGC, GKYXC, NHWGK); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKYXC, NGKHW); - else if constexpr(layouts_match.template operator()()) - return make_layouts(NGCHW, GKCYX, NGKHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable - } - } - else if constexpr(spatial_dim == 3) - { - if constexpr(layouts_match.template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(GNDHWC, GKZYXC, GNDHWK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NDHWGC, GKZYXC, NDHWGK); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NGCDHW, GKZYXC, NGKDHW); - else if constexpr(layouts_match - .template operator()()) - return make_layouts(NGCDHW, GKCZYX, NGKDHW); - else - { - report_unsupported_layout_error(); - return make_layouts(GNDHWC, GKZYXC, GNDHWK); // Unreachable - } - } - else - { - report_unsupported_layout_error(); - return make_layouts(GNHWC, GKYXC, GNHWK); // Unreachable + case 1: + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(GNWC, GKXC, GNWK); + if constexpr(layouts_are) + return layouts(NWGC, GKXC, NWGK); + if constexpr(layouts_are) + return layouts(NGCW, GKXC, NGKW); + if constexpr(layouts_are) + return layouts(NGCW, GKCX, NGKW); + break; + case 2: + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(GNHWC, GKYXC, GNHWK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + return layouts(NHWGC, GKYXC, NHWGK); + if constexpr(layouts_are) + 
return layouts(NGCHW, GKYXC, NGKHW); + if constexpr(layouts_are) + return layouts(NGCHW, GKCYX, NGKHW); + break; + case 3: + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(GNDHWC, GKZYXC, GNDHWK); + if constexpr(layouts_are) + return layouts(NDHWGC, GKZYXC, NDHWGK); + if constexpr(layouts_are) + return layouts(NGCDHW, GKZYXC, NGKDHW); + if constexpr(layouts_are) + return layouts(NGCDHW, GKCZYX, NGKDHW); + break; } + + // If we reach here, the layout combination is not supported + // Call consteval function to trigger a compile-time error with a clear message + report_unsupported_layout_error(); + + // This return is unreachable but needed to satisfy the compiler + return layouts(GNHWC, GKYXC, GNHWK); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout + +template +constexpr auto fwd_conv_layout() + requires HasFwdConvLayouts> +{ + + using A = typename InstanceTraits::ALayout; + using B = typename InstanceTraits::BLayout; + using E = typename InstanceTraits::ELayout; + return conv_layout::kSpatialDim>(); +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. 
+/// @return An std::array corresponding to the tensor layouts: +/// index 0 -> Input layout +/// index 1 -> Weight layout +/// index 2 -> Output layout +template +constexpr auto bwd_wei_conv_layout() + requires HasBwdWeiLayouts> +{ + + using A = typename InstanceTraits::InLayout; + using B = typename InstanceTraits::WeiLayout; + using E = typename InstanceTraits::OutLayout; + return conv_layout::kSpatialDim>(); } // ---------------------------------------------------------------------------- @@ -447,13 +464,11 @@ constexpr std::array conv_layout() // ---------------------------------------------------------------------------- /// @brief Helper function to report unsupported data type with a clear error message. -/// @details This consteval function uses throw (not static_assert) to ensure the error is not -/// silently ignored during SFINAE. The thrown string becomes part of the compiler error message. -template +template [[noreturn]] consteval void report_unsupported_data_type_error() { throw "Unsupported data type detected!\n" - "The ADataType is not recognized.\n" + "The DataTypeFromInstance is not recognized.\n" "Supported types are: ck::half_t (FP16), ck::Tuple (FP16_FP16), " "ck::bhalf_t (BF16), ck::Tuple (BF16_BF16), float (FP32), " "ck::Tuple (FP32_FP32), double (FP64), ck::f8_t (FP8), ck::bf8_fnuz_t " @@ -462,62 +477,44 @@ template "Please verify that your kernel instance uses a supported data type."; } -/// @brief Derives the data type from a device kernel Instance type. -/// @tparam Instance The device kernel instance type. -/// @return A builder::DataType enum value representing the input data type. -/// @details This function examines the Instance's ADataType to determine the data type -/// used for the input tensor. The function supports various floating-point and integer -/// types, including tuple types for mixed-precision operations. 
-/// -/// Supported data types include: -/// - FP16 (ck::half_t) -/// - FP16_FP16 (ck::Tuple) -/// - BF16 (ck::bhalf_t) -/// - BF16_BF16 (ck::Tuple) -/// - FP32 (float) -/// - FP32_FP32 (ck::Tuple) -/// - FP64 (double) -/// - FP8 (ck::f8_t) -/// - BF8 (ck::bf8_fnuz_t, ck::bf8_ocp_t) -/// - I8 (int8_t) -/// - I8_I8 (ck::Tuple) -/// - U8 (uint8_t) -template +/// @brief Derives the data type from a device kernel `Instance` type. +/// Returns a `builder::DataType` enum value (e.g., FP16, BF16, FP32, BF8). +// Note: maybe move to types.hpp? +template constexpr builder::DataType conv_data_type() + { - using InstTraits = InstanceTraits; - using ADataType = typename InstTraits::ADataType; using enum builder::DataType; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return FP16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP16_FP16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF16; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return BF16_BF16; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP32; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return FP32_FP32; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP64; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return FP8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return BF8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return I8; - else if constexpr(std::is_same_v>) + else if constexpr(std::is_same_v>) return I8_I8; - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) return U8; else { - report_unsupported_data_type_error(); + report_unsupported_data_type_error(); return FP32; // Unreachable } } @@ -736,4 +733,92 @@ constexpr 
builder::PipelineScheduler get_pipeline_scheduler() } } +// ============================================================================ +// SECTION 4: Helper functions for common structures often used in reflection +// ============================================================================ + +template +constexpr DataTileInfo conv_traits_data_tile(int k_or_k0 = InstTraits::kKPerBlock) +{ + return DataTileInfo{.m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = k_or_k0}; +} + +template +constexpr InputTileTransferInfo +conv_traits_a_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kMPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; +} + +template +constexpr InputTileTransferInfo +conv_traits_b_transfer_params(int _k1, int kPerBlock = InstTraits::kKPerBlock) +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = kPerBlock / _k1, .m_or_n = InstTraits::kNPerBlock, .k1 = _k1}, + .transfer_params = {.k1 = _k1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = 
static_cast(InstTraits::kBBlockLdsExtraN)}}; +} + +template +constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma, + .gemm_n = InstTraits::kNPerWmma, + .m_iter = InstTraits::kMRepeat, + .n_iter = InstTraits::kNRepeat}; +} + +template +constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; +} + +template +constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle}, + .thread_cluster_dims = {InstTraits::kCDEThreadClusterLengths[0], + InstTraits::kCDEThreadClusterLengths[1], + InstTraits::kCDEThreadClusterLengths[2], + InstTraits::kCDEThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCDEBlockTransferScalarPerVector}; +} + +template +constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; +} + } // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp index 00010e2d48b..e10baaf7120 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp +++ 
b/experimental/builder/include/ck_tile/builder/reflect/instance_to_conv_traits.hpp @@ -3,6 +3,18 @@ #pragma once +// Fwd instances #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" #include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +// Bwd weight instances +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp" +#include "ck_tile/builder/reflect/conv_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index cde1896993b..c3a5f9df29d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ 
-62,6 +62,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Wmma_CShuffle_V3_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static 
constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -231,7 +256,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -251,30 +276,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. NPerBlock - oss << "," << kKPerBlock; // 18. KPerBlock - oss << "," << kABK1; // 19. ABK1 - oss << "," << kMPerWmma; // 20. MPerWmma - oss << "," << kNPerWmma; // 21. NPerWmma - oss << "," << kMRepeat; // 22. MRepeat - oss << "," << kNRepeat; // 23. NRepeat + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kKPerBlock; // 18. KPerBlock + oss << "," << kK1; // 19. ABK1 + oss << "," << kMPerWmma; // 20. MPerWmma + oss << "," << kNPerWmma; // 21. 
NPerWmma + oss << "," << kMRepeat; // 22. MRepeat + oss << "," << kNRepeat; // 23. NRepeat oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMRepeatPerShuffle; // 38. oss << "," << kCShuffleNRepeatPerShuffle; // 39. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 6508ac7d6eb..173da8268a3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -59,6 +59,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_multiple_d_Xdl_CShuffle_Tag +{ +}; template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto 
kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; @@ -211,6 +232,9 @@ struct InstanceTraits< using ComputeTypeA = ComputeTypeA_; using ComputeTypeB = ComputeTypeB_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + // Static member function to generate instance string static std::string instance_string() { @@ -220,7 +244,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +264,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 14. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 15. BlockSize - oss << "," << kMPerBlock; // 16. MPerBlock - oss << "," << kNPerBlock; // 17. NPerBlock - oss << "," << kK0PerBlock; // 18. K0PerBlock - oss << "," << kK1; // 19. K1 - oss << "," << kMPerXDL; // 20. 
MPerXDL - oss << "," << kNPerXDL; // 21. NPerXDL - oss << "," << kMXdlPerWave; // 22. MXdlPerWave - oss << "," << kNXdlPerWave; // 23. NXdlPerWave + kConvBwdWeightSpecialization); // 14. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 15. BlockSize + oss << "," << kMPerBlock; // 16. MPerBlock + oss << "," << kNPerBlock; // 17. NPerBlock + oss << "," << kK0PerBlock; // 18. K0PerBlock + oss << "," << kK1; // 19. K1 + oss << "," << kMPerXDL; // 20. MPerXDL + oss << "," << kNPerXDL; // 21. NPerXDL + oss << "," << kMXdlPerWave; // 22. MXdlPerWave + oss << "," << kNXdlPerWave; // 23. NXdlPerWave oss << "," << detail::sequence_name(); // 24. oss << "," << detail::sequence_name(); // 25. oss << "," << detail::sequence_name(); // 26. oss << "," << kABlockTransferSrcVectorDim; // 27. oss << "," << kABlockTransferSrcScalarPerVector; // 28. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 29. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 30. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 29. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << detail::sequence_name(); // 32. oss << "," << detail::sequence_name(); // 33. oss << "," << kBBlockTransferSrcVectorDim; // 34. oss << "," << kBBlockTransferSrcScalarPerVector; // 35. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 36. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 37. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 36. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 37. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 38. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 39. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index f1e40de7d21..4b90a6ab64d 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Wmma_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; @@ -215,13 +231,26 @@ struct InstanceTraits< static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr 
ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; + static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -237,7 +266,7 @@ struct InstanceTraits< oss << "DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3"; // Template parameters in exact order - oss << "<" << kNDimSpatial; // 1. NDimSpatial + oss << "<" << kSpatialDim; // 1. NDimSpatial oss << "," << detail::layout_name(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -255,30 +284,30 @@ struct InstanceTraits< // OutElementwiseOperation oss << "," << detail::conv_bwd_weight_spec_name( - kConvBackwardWeightSpecialization); // 12. ConvBackwardWeightSpecialization - oss << "," << kBlockSize; // 13. BlockSize - oss << "," << kMPerBlock; // 14. MPerBlock - oss << "," << kNPerBlock; // 15. NPerBlock - oss << "," << kKPerBlock; // 16. KPerBlock - oss << "," << kABK1; // 17. ABK1 - oss << "," << kMPerWmma; // 18. 
MPerWmma - oss << "," << kNPerWmma; // 19. NPerWmma - oss << "," << kMRepeat; // 20. MRepeat - oss << "," << kNRepeat; // 21. NRepeat + kConvBwdWeightSpecialization); // 12. ConvBackwardWeightSpecialization + oss << "," << kBlockSize; // 13. BlockSize + oss << "," << kMPerBlock; // 14. MPerBlock + oss << "," << kNPerBlock; // 15. NPerBlock + oss << "," << kKPerBlock; // 16. KPerBlock + oss << "," << kABK1; // 17. ABK1 + oss << "," << kMPerWmma; // 18. MPerWmma + oss << "," << kNPerWmma; // 19. NPerWmma + oss << "," << kMRepeat; // 20. MRepeat + oss << "," << kNRepeat; // 21. NRepeat oss << "," << detail::sequence_name(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 460b49de937..999aff6f1e1 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -63,6 +63,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto 
kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -234,7 +260,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -252,30 +278,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. 
oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index f87e295159a..eba422b85f0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -59,6 +59,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag +{ +}; + template > // Use false to match with the default value { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffle"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_Tag; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -164,15 +170,15 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = 
ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::LoopScheduler kLoopSched = LoopSched; static constexpr ck::PipelineVersion kPipelineVer = PipelineVer; @@ -216,7 +239,7 @@ struct 
InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -234,30 +257,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 29459d67b01..cfc8b4e05af 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -62,6 +62,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_v3 device kernel +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag +{ +}; + template > { static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Wmma_CShuffleV3"; + using device_kernel_tag = DeviceGroupedConvBwdWeight_Wmma_CShuffle_V3_Tag; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -172,13 +178,13 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_AK1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_AK1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_BK0_N_BK1 = BBlockTransferThreadClusterLengths_BK0_N_BK1_; using 
BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_BK1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_BK1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; using CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; - + static constexpr auto kCDEThreadClusterLengths = detail::SequenceToArray< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr int kCDEBlockTransferScalarPerVector = + CShuffleBlockTransferScalarPerVector_NPerBlock; static constexpr ck::BlockGemmPipelineScheduler kBlkGemmPipeSched = BlkGemmPipeSched; static constexpr ck::BlockGemmPipelineVersion kBlkGemmPipelineVer = BlkGemmPipelineVer; @@ -232,7 +257,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -250,30 +275,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. 
- oss << "," << kABlockTransferDstScalarPerVector_AK1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_BK1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMRepeatPerShuffle; // 36. oss << "," << kCShuffleNRepeatPerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 2c893b9c1dd..1edf03740f3 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_Tag +{ +}; + template ::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = 
ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; + static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = 
CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -224,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -242,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. 
oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 147028f9cfb..ce23dac1d7e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -61,6 +61,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; namespace ck_tile { namespace reflect { +/// @brief Tag type for DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 device kernel +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag +{ +}; + template > { + + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3_Tag; static constexpr auto kTensorOpName = "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"; - static constexpr ck::index_t kNDimSpatial = NDimSpatial; + static constexpr ck::index_t kSpatialDim = NDimSpatial; using InLayout = InLayout_; using WeiLayout = WeiLayout_; @@ -167,7 +175,7 @@ struct InstanceTraits::value; + static constexpr auto kAThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kABlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kABlockTransferSrcVectorDim = ABlockTransferSrcVectorDim; static constexpr ck::index_t kABlockTransferSrcScalarPerVector = ABlockTransferSrcScalarPerVector; - static constexpr ck::index_t kABlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kABlockTransferDstScalarPerVectorK1 = ABlockTransferDstScalarPerVector_K1; - static constexpr bool kABlockLdsAddExtraM = ABlockLdsAddExtraM; + static constexpr bool kABlockLdsExtraM = ABlockLdsAddExtraM; using BBlockTransferThreadClusterLengths_K0_N_K1 = 
BBlockTransferThreadClusterLengths_K0_N_K1_; using BBlockTransferThreadClusterArrangeOrder = BBlockTransferThreadClusterArrangeOrder_; using BBlockTransferSrcAccessOrder = BBlockTransferSrcAccessOrder_; + + // B block transfer thread cluster dimensions (converted to std::array) + static constexpr auto kBThreadClusterLengths = + detail::SequenceToArray::value; + static constexpr auto kBThreadClusterArrangeOrder = + detail::SequenceToArray::value; + static constexpr auto kBBlockTransferSrcAccessOrder = + detail::SequenceToArray::value; + static constexpr ck::index_t kBBlockTransferSrcVectorDim = BBlockTransferSrcVectorDim; static constexpr ck::index_t kBBlockTransferSrcScalarPerVector = BBlockTransferSrcScalarPerVector; - static constexpr ck::index_t kBBlockTransferDstScalarPerVector_K1 = + static constexpr ck::index_t kBBlockTransferDstScalarPerVectorK1 = BBlockTransferDstScalarPerVector_K1; - static constexpr bool kBBlockLdsAddExtraN = BBlockLdsAddExtraN; + static constexpr bool kBBlockLdsExtraN = BBlockLdsAddExtraN; static constexpr ck::index_t kCShuffleMXdlPerWavePerShuffle = CShuffleMXdlPerWavePerShuffle; static constexpr ck::index_t kCShuffleNXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle; using CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_; + static constexpr auto kCThreadClusterLengths = detail::SequenceToArray< + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>::value; static constexpr ck::index_t kCBlockTransferScalarPerVector_NWaveNPerXdl = CBlockTransferScalarPerVector_NWaveNPerXdl; @@ -222,7 +250,7 @@ struct InstanceTraits(); // 2. InLayout oss << "," << detail::layout_name(); // 3. WeiLayout oss << "," << detail::layout_name(); // 4. OutLayout @@ -240,30 +268,30 @@ struct InstanceTraits(); // 22. oss << "," << detail::sequence_name(); // 23. oss << "," << detail::sequence_name(); // 24. oss << "," << kABlockTransferSrcVectorDim; // 25. 
oss << "," << kABlockTransferSrcScalarPerVector; // 26. - oss << "," << kABlockTransferDstScalarPerVector_K1; // 27. - oss << "," << (kABlockLdsAddExtraM ? "true" : "false"); // 28. + oss << "," << kABlockTransferDstScalarPerVectorK1; // 27. + oss << "," << (kABlockLdsExtraM ? "true" : "false"); // 28. oss << "," << detail::sequence_name(); // 29. oss << "," << detail::sequence_name(); // 30. oss << "," << detail::sequence_name(); // 31. oss << "," << kBBlockTransferSrcVectorDim; // 32. oss << "," << kBBlockTransferSrcScalarPerVector; // 33. - oss << "," << kBBlockTransferDstScalarPerVector_K1; // 34. - oss << "," << (kBBlockLdsAddExtraN ? "true" : "false"); // 35. + oss << "," << kBBlockTransferDstScalarPerVectorK1; // 34. + oss << "," << (kBBlockLdsExtraN ? "true" : "false"); // 35. oss << "," << kCShuffleMXdlPerWavePerShuffle; // 36. oss << "," << kCShuffleNXdlPerWavePerShuffle; // 37. oss << "," diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 782fd158c53..645d75258e0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -79,6 +79,11 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; } // namespace ck::tensor_operation::device +/// @brief Tag type for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device kernel +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag +{ +}; + namespace ck_tile::reflect { // Specialization for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -176,6 +181,8 @@ struct InstanceTraits> { + /// @brief Tag type identifying this device kernel variant + using device_kernel_tag = 
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_Tag; // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; diff --git a/experimental/builder/test/conv/ck/test_conv_traits.cpp b/experimental/builder/test/conv/ck/test_conv_traits.cpp index 42235df2fe0..32211135653 100644 --- a/experimental/builder/test/conv/ck/test_conv_traits.cpp +++ b/experimental/builder/test/conv/ck/test_conv_traits.cpp @@ -9,8 +9,17 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include +#include + namespace { using ck_tile::builder::ConvDirection; @@ -26,6 +35,1099 @@ class ConvTraitsTest : public ::testing::Test { }; +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + 3, // NDimSpatial + ck::tensor_layout::convolution::GNDHWC, // InLayout + ck::tensor_layout::convolution::GKZYXC, // WeiLayout + ck::tensor_layout::convolution::GNDHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // 
ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + 1, // NummGemmKPrefetchStage + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + false>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 3); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 
128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + 
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdWeightCshuffleWmmaV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // 
CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + 
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleWmmaV3TraitsExtraction) +{ + // Define a 
concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // 
BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile 
transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageWmmaCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + 
ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 32, // MPerWMMA + 32, // NPerXDL + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + 
EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + 
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_two_stage_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightTwoStageXdlCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + 
ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + 
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 
0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightMultipleDCshuffleXDLTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::Tuple<>, // DsLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::Tuple<>, // DsDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // 
OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + 
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + 
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 +TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleV3TraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // 
ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t>; // BComputeDataType + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); 
+ EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + 
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvBwdWeight_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvBwdWeightXdlCshuffleTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim 
+ 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::BACKWARD_WEIGHT); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + 
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::DEFAULT); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + +// test conv traits 
device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +TEST_F(ConvTraitsTest, ConvFwdTraitsMultipleDCshuffleWmmaExtraction) +{ + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NummGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMRepeatPerShuffle + 1, // CShuffleNRepeatPerShuffle + ck::Sequence< + 1, + 32, + 1, + 8>, // 
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEShuffleBlockTransferScalarPerVector_NPerBlock + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1>; // BlkGemmPipelineVer + + // Use ConvTraitsTmpl to extract compile-time information + const auto traits = ck_tile::reflect::conv::instance_to_conv_traits(); + + // Verify signature information + EXPECT_EQ(traits.spatial_dim, 2); + EXPECT_EQ(traits.direction, ConvDirection::FORWARD); + EXPECT_THAT(traits.layout, + ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK)); + EXPECT_EQ(traits.data_type, DataType::FP16); + EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT); + EXPECT_EQ(traits.num_gemm_k_prefetch_stage, 1); + + // Verify algorithm information + EXPECT_EQ(traits.thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(traits.tile_dims.m, 128); + EXPECT_EQ(traits.tile_dims.n, 128); + EXPECT_EQ(traits.tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2); + 
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(traits.warp_gemm.gemm_m, 32); + EXPECT_EQ(traits.warp_gemm.gemm_n, 32); + EXPECT_EQ(traits.warp_gemm.m_iter, 4); + EXPECT_EQ(traits.warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(traits.pipeline_scheduler, + ck_tile::reflect::conv::convert_pipeline_scheduler()); + EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1); +} + // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) { diff --git 
a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index bcea406fa77..4cbde73bcbc 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -259,9 +259,118 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) static constexpr const ConvSignature SIGNATURE; static constexpr const DefaultAlgorithm ALGORITHM; using Instance = ckb::ConvBuilder::Instance; + EXPECT_THAT( + ckr::describe().detailed(), + ckt::StringEqWithDiff( // + "2D Forward Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Input Layout: GNHWC\n" + "│ ├─ Weight Layout: GKYXC\n" + "│ ├─ Output Layout: GNHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "└─ Algorithm\n" + " ├─ Thread block size: 256\n" + " ├─ Data tile size: 256×256×32\n" + " ├─ Gemm padding: DEFAULT\n" + " ├─ Convolution specialization: DEFAULT\n" + " ├─ Pipeline version: V4\n" + " ├─ Pipeline scheduler: INTRAWAVE\n" + " ├─ Warp Gemm parameters: \n" + " │ ├─ subtile size: 16×16\n" + " │ └─ Number of warp gemm iterations: 8×8\n" + " └─ Memory access:\n" + " ├─ A Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " ├─ B Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data 
tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 2\n" + " │ ├─ Vector access (LDS write) instruction size: 2\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " └─ C Tile transfer: \n" + " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " ├─ Vector access (GMEM write) instruction size: 2\n" + " ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n" + " ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector " + "parameter\n" + " ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n" + " └─ Struct does not contain optional num_groups_to_merge parameter")); +} + +// Test printing of optional parameters num_groups_to_merge, +// max_transpose_transfer_src_scalar_per_vector and max_transpose_dst_scalar_per_vector +TEST(ConvDescriptionTest, BwdWeightTwoStageWmmaV3DescriptionTest) +{ + using Instance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // InLayout + ck::tensor_layout::convolution::GKYXC, // WeiLayout + ck::tensor_layout::convolution::GNHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // AK1 + 32, // MPerWMMA + 32, // NPerWMMA + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 
64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + 4, // NumGroupsToMerge + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + 1, // MaxTransposeTransferSrcScalarPerVector + 1>; // MaxTransposeTransferDstScalarPerVector> + EXPECT_THAT(ckr::describe().detailed(), ckt::StringEqWithDiff( // - "2D Forward Convolution Kernel\n" + "2D Backward Weight Convolution Kernel\n" "├─ Signature\n" "│ ├─ Tensor Type: FP16\n" "│ ├─ Input Layout: GNHWC\n" @@ -272,37 +381,146 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) "│ └─ Output elementwise operation: PASS_THROUGH\n" "└─ Algorithm\n" " ├─ Thread block size: 256\n" - " ├─ Data tile size: 256×256×32\n" - " ├─ Gemm padding: DEFAULT\n" + " ├─ Data tile size: 128×128×16\n" + " ├─ Struct does not contain optional gemm_padding argument\n" " ├─ Convolution specialization: DEFAULT\n" - " ├─ Pipeline version: V4\n" - " ├─ Pipeline scheduler: INTRAWAVE\n" + " ├─ Pipeline version: V1\n" + " ├─ Pipeline scheduler: DEFAULT\n" " ├─ Warp Gemm 
parameters: \n" - " │ ├─ subtile size: 16×16\n" - " │ └─ Number of warp gemm iterations: 8×8\n" + " │ ├─ subtile size: 32×32\n" + " │ └─ Number of warp gemm iterations: 4×4\n" " └─ Memory access:\n" " ├─ A Tile transfer: \n" - " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ Tile dimensions: 2×128×8×\n" " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ ├─ The order of accessing data tile axes: 1×0×2\n" " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" " ├─ B Tile transfer: \n" - " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ Tile dimensions: 2×128×8×\n" " │ ├─ The innermost K subdimension size: 8\n" - " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ ├─ The order of accessing data tile axes: 1×0×2\n" " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - " │ ├─ Vector access (GMEM read) instruction size: 2\n" - " │ ├─ Vector access (LDS write) instruction size: 2\n" - " │ └─ LDS data layout padding (to prevent bank conflicts): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" " └─ C Tile transfer: \n" " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" " ├─ Spatial thread distribution 
used to store data: 1×32×1×8\n" - " └─ Vector access (GMEM write) instruction size: 2")); + " ├─ Vector access (GMEM write) instruction size: 8\n" + " ├─ Struct does not contain optional num_gemm_k_prefetch_stage parameter\n" + " ├─ Max Transpose transfer scr scalar per vector: 1\n" + " ├─ Max Transpose dst scalar per vector: 1\n" + " └─ Num groups to merge: 4")); +} + +// Test printing of the optional num_gemm_k_prefetch_stage parameter and the absence of the +// max_transpose_transfer_src_scalar_per_vector, max_transpose_dst_scalar_per_vector and num_groups_to_merge parameters +TEST(ConvDescriptionTest, BwdWeightWmmaCshuffleV3DescriptionTest) +{ + using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + 3, // NDimSpatial + ck::tensor_layout::convolution::GNDHWC, // InLayout + ck::tensor_layout::convolution::GKZYXC, // WeiLayout + ck::tensor_layout::convolution::GNDHWK, // OutLayout + ck::half_t, // InDataType + ck::half_t, // WeiDataType + ck::half_t, // OutDataType + float, // AccDataType + ck::tensor_operation::element_wise::PassThrough, // InElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // WeiElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // OutElementwiseOperation + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization:: + Default, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 8, // K1 + 32, // MPerWmma + 32, // NPerWmma + 4, // MRepeat + 4, // NRepeat + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsAddExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder_ + ck::Sequence<1, 0, 
2>, // BBlockTransferSrcAccessOrder_ + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock_ + 8, // CDEBlockTransferScalarPerVector_NPerBlock_ + 1, // NummGemmKPrefetchStage + ck::LoopScheduler::Default, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + false>; // BComputeDataType + + EXPECT_THAT( + ckr::describe().detailed(), + ckt::StringEqWithDiff( // + "3D Backward Weight Convolution Kernel\n" + "├─ Signature\n" + "│ ├─ Tensor Type: FP16\n" + "│ ├─ Input Layout: GNDHWC\n" + "│ ├─ Weight Layout: GKZYXC\n" + "│ ├─ Output Layout: GNDHWK\n" + "│ ├─ Input elementwise operation: PASS_THROUGH\n" + "│ ├─ Weights elementwise operation: PASS_THROUGH\n" + "│ └─ Output elementwise operation: PASS_THROUGH\n" + "└─ Algorithm\n" + " ├─ Thread block size: 256\n" + " ├─ Data tile size: 128×128×16\n" + " ├─ Struct does not contain optional gemm_padding argument\n" + " ├─ Convolution specialization: DEFAULT\n" + " ├─ Pipeline version: V1\n" + " ├─ Pipeline scheduler: DEFAULT\n" + " ├─ Warp Gemm parameters: \n" + " │ ├─ subtile size: 32×32\n" + " │ └─ Number of warp gemm iterations: 4×4\n" + " └─ Memory access:\n" + " ├─ A Tile transfer: \n" + " │ ├─ Tile dimensions: 2×128×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ ├─ The order of accessing data tile axes: 1×0×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " ├─ B Tile transfer: \n" + " │ ├─ Tile dimensions: 2×128×8×\n" + " │ ├─ The innermost K 
subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 1×0×2\n" + " │ ├─ The order of accessing data tile axes: 1×0×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " └─ C Tile transfer: \n" + " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " ├─ Vector access (GMEM write) instruction size: 8\n" + " ├─ Num gemm k prefetch stage: 1\n" + " ├─ Struct does not contain optional max_transpose_transfer_src_scalar_per_vector " + "parameter\n" + " ├─ Struct does not contain optional max_transpose_dst_scalar_per_vector parameter\n" + " └─ Struct does not contain optional num_groups_to_merge parameter")); } TEST(ConvDescriptionTest, DefaultInstanceHasInstanceString)