Skip to content

Commit c5ec2f7

Browse files
tenpercent and claude authored
Add generate_identity_sequences helper and replace lambdas with named functors (#4828)
## Summary

- Add `generate_identity_sequences<N>()` helper that returns `Tuple<Sequence<0>, Sequence<1>, ..., Sequence<N-1>>`
- Replace lambdas with named functors in `transform_tensor_descriptor`
- Add `unpack_and_merge_sequences` helper functor
- Reduces `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction)

## Motivation

Multiple call sites use the `generate_tuple([](auto i) { return Sequence<i>{}; }, Number<N>{})` pattern. A named helper reduces lambda instantiations. Additionally, each lambda in `transform_tensor_descriptor` creates a unique closure type, causing the function to be instantiated separately for every call site. Named functors share a single type, so the compiler reuses the same instantiation.

## Changes

### Part 1: generate_identity_sequences helper

- Replaces common lambda pattern for generating identity sequences
- Each lambda expression creates a unique closure type, causing separate template instantiations at every call site
- Named helper shares a single type across all uses

### Part 2: Named functors in transform_tensor_descriptor

- Add `unpack_and_merge_sequences` helper to replace lambda in `GetNumOfHiddenDimension`
- Use `generate_identity_sequences` in `matrix_padder.hpp`

## Test Plan

- [x] Added 7 unit tests:
  - 4 tests for `generate_identity_sequences`
  - 3 tests for `unpack_and_merge_sequences`
- [ ] Waiting for full CI

## Related PRs

This PR merges the functionality from:

- ROCm#3588 (generate_identity_sequences helper)
- ROCm#3589 (Named functors in transform_tensor_descriptor)

Part of PR stack for issue #4229 (Reduce CK/CKTile Build Times)

**Note:** This PR supersedes #4283, ROCm#3588 and ROCm#3589, which can be closed once this is merged.

---

🔁 Imported from [ROCm#3628](ROCm#3628)
🧑‍💻 Originally authored by @tenpercent

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 68dff9d commit c5ec2f7

19 files changed

Lines changed: 550 additions & 74 deletions

include/ck/tensor_description/tensor_descriptor.hpp

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,9 @@ struct TensorDescriptor
3838

3939
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
4040
{
41-
constexpr auto all_low_dim_ids = unpack(
42-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
41+
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
4342

44-
constexpr auto all_up_dim_ids = unpack(
45-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
43+
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
4644

4745
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
4846

@@ -319,6 +317,41 @@ struct lambda_get_up_dim_num
319317
}
320318
};
321319

320+
// Named functor that maps one visible dimension ID to its hidden dimension ID.
// operator() takes the visible ID as a plain index_t, looks up that position in
// OldTensorDescriptor::GetVisibleDimensionIds() via At(), and returns the entry found there.
// It is stateless: everything it needs comes from the OldTensorDescriptor template argument.
321+
template <typename OldTensorDescriptor>
322+
struct convert_visible_to_hidden_id
323+
{
324+
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
325+
{
326+
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
327+
}
328+
};
329+
330+
// Named functor that maps a whole sequence of visible IDs to hidden IDs.
// operator() passes its argument, together with a default-constructed
// convert_visible_to_hidden_id<OldTensorDescriptor>, to transform_sequences and returns the
// result. The argument type is deduced per call (LowDimVisibleIds) and taken by value.
// NOTE(review): presumably transform_sequences applies the functor element-wise — confirm
// against its definition, which is not part of this hunk.
331+
template <typename OldTensorDescriptor>
332+
struct convert_visible_ids_to_hidden_ids
333+
{
334+
template <typename LowDimVisibleIds>
335+
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
336+
{
337+
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
338+
low_dim_visible_ids);
339+
}
340+
};
341+
342+
// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions.
// Template parameters:
//   OldHiddenDimNumber - offset added to both range bounds.
//   UpDimNumbersScan   - default-constructible type whose At() is queried at I and at I + 1.
// operator()(I) computes
//   start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{})
//   end   = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{})
// and returns a default-constructed arithmetic_sequence_gen<start, end, 1>::type.
// The argument carries no runtime data; only its type I is used, so both bounds are
// compile-time constants.
// NOTE(review): presumably the generated range is half-open [start, end) — confirm against
// arithmetic_sequence_gen, which is not part of this hunk.
343+
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
344+
struct generate_arithmetic_sequence_from_scan
345+
{
346+
template <typename I>
347+
__host__ __device__ constexpr auto operator()(I) const
348+
{
349+
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
350+
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
351+
return typename arithmetic_sequence_gen<start, end, 1>::type{};
352+
}
353+
};
354+
322355
template <typename OldTensorDescriptor,
323356
typename NewTransforms,
324357
typename NewLowerDimensionOldVisibleIdss,
@@ -335,11 +368,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
335368
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
336369
"wrong! inconsitent number of transform");
337370

338-
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
339-
NewLowerDimensionOldVisibleIdss{});
371+
constexpr auto all_old_top_ids =
372+
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
340373

341-
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
342-
NewUpperDimensionNewVisibleIdss{});
374+
constexpr auto all_new_top_ids =
375+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
343376

344377
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
345378
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -349,17 +382,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
349382
// lower dimension's hidden idss
350383
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
351384
// sequences)
352-
constexpr auto low_dim_hidden_idss = transform_tuples(
353-
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
354-
[](auto low_dim_visible_ids) constexpr {
355-
return transform_sequences(
356-
// convert lower dimension visible id to hidden id
357-
[](auto low_dim_visible_id) constexpr {
358-
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
359-
},
360-
low_dim_visible_ids);
361-
},
362-
NewLowerDimensionOldVisibleIdss{});
385+
constexpr auto low_dim_hidden_idss =
386+
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
387+
NewLowerDimensionOldVisibleIdss{});
363388

364389
constexpr index_t num_new_transform = NewTransforms::Size();
365390

@@ -372,22 +397,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
372397
constexpr auto up_dim_numbers_scan = merge_sequences(
373398
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
374399

400+
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
375401
constexpr auto up_dim_hidden_idss = generate_tuple(
376-
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
377-
return
378-
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
379-
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
380-
1>::type{};
381-
},
402+
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
382403
Number<num_new_transform>{});
383404

384405
// new visible dimension's hidden ids
385406
constexpr auto unordered_new_visible_dim_hidden_ids =
386-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
407+
unpack_and_merge_sequences(up_dim_hidden_idss);
387408

388409
constexpr auto new_visible_dim_unordered2ordered =
389-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
390-
NewUpperDimensionNewVisibleIdss{});
410+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
391411

392412
constexpr auto new_visible_dim_hidden_ids =
393413
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

include/ck/tensor_operation/gpu/device/matrix_padder.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
4343
},
4444
Number<num_dim>{});
4545

46-
// lower dimension Id
47-
const auto lower_dimss =
48-
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49-
50-
// upper dimension Id
46+
// lower/upper dimension Ids
47+
const auto lower_dimss = generate_identity_sequences<num_dim>();
5148
const auto upper_dimss = lower_dimss;
5249

5350
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
739739
},
740740
Number<nDim>{});
741741

742-
constexpr auto up_dim_idss =
743-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
742+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
744743

745744
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
746745
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
894894
},
895895
Number<nDim>{});
896896

897-
constexpr auto up_dim_idss =
898-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
897+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
899898

900899
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
901900
}
@@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
944943
},
945944
Number<nDim>{});
946945

947-
constexpr auto up_dim_idss =
948-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
946+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
949947

950948
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
951949
}
@@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
993991
},
994992
Number<nDim>{});
995993

996-
constexpr auto up_dim_idss =
997-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
994+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
998995

999996
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
1000997
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
833833
},
834834
Number<nDim>{});
835835

836-
constexpr auto up_dim_idss =
837-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
836+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
838837

839838
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
840839
}
@@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
892891
},
893892
Number<nDim>{});
894893

895-
constexpr auto up_dim_idss =
896-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
894+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
897895

898896
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
899897
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
692692
},
693693
Number<nDim>{});
694694

695-
constexpr auto up_dim_idss =
696-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
695+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
697696

698697
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
699698
}
@@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
744743
},
745744
Number<nDim>{});
746745

747-
constexpr auto up_dim_idss =
748-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
746+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
749747

750748
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
751749
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
514514
},
515515
Number<nDim>{});
516516

517-
constexpr auto up_dim_idss =
518-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
517+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
519518

520519
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
521520
}
@@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
563562
},
564563
Number<nDim>{});
565564

566-
constexpr auto up_dim_idss =
567-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
565+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
568566

569567
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
570568
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -657,8 +657,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
657657
},
658658
Number<nDim>{});
659659

660-
constexpr auto up_dim_idss =
661-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
660+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
662661

663662
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
664663
}
@@ -707,8 +706,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
707706
},
708707
Number<nDim>{});
709708

710-
constexpr auto up_dim_idss =
711-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
709+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
712710

713711
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
714712
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
548548
},
549549
Number<nDim>{});
550550

551-
constexpr auto up_dim_idss =
552-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
551+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
553552

554553
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
555554
}
@@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
598597
},
599598
Number<nDim>{});
600599

601-
constexpr auto up_dim_idss =
602-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
600+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
603601

604602
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
605603
}

include/ck/utility/sequence_helper.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include "ck/utility/functional4.hpp"
67
#include "ck/utility/tuple.hpp"
78

89
namespace ck {
@@ -34,4 +35,21 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
3435
return Sequence<Is...>{};
3536
}
3637

38+
// Functor wrapper for merge_sequences to enable reuse across call sites.
// operator() accepts any number of arguments by value, forwards them unchanged to
// merge_sequences, and returns whatever merge_sequences returns.
// Being a named struct, it is one type wherever it is used, so it can be handed to unpack()
// in place of an inline lambda.
39+
struct merge_sequences_functor
40+
{
41+
template <typename... Seqs>
42+
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
43+
{
44+
return merge_sequences(seqs...);
45+
}
46+
};
47+
48+
// Unpacks tuple of sequences and merges them into a single sequence.
// Calls unpack() with a default-constructed merge_sequences_functor and the given tuple
// (taken by value), and returns the result of that call.
// NOTE(review): presumably unpack() expands the tuple's elements into the functor's argument
// list — confirm against ck/utility/functional4.hpp, which this change newly includes.
49+
template <typename TupleOfSequences>
50+
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
51+
{
52+
return unpack(merge_sequences_functor{}, tuple_of_sequences);
53+
}
54+
3755
} // namespace ck

0 commit comments

Comments
 (0)