-
Notifications
You must be signed in to change notification settings - Fork 269
Add generate_identity_sequences helper and replace lambdas with named functors #3628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
b8cd896
7c3ab8a
905ae13
84a09ef
8031a96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,11 +36,9 @@ struct TensorDescriptor | |
|
|
||
| __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() | ||
| { | ||
| constexpr auto all_low_dim_ids = unpack( | ||
| [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); | ||
| constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{}); | ||
|
|
||
| constexpr auto all_up_dim_ids = unpack( | ||
| [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); | ||
| constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{}); | ||
|
|
||
| constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); | ||
|
|
||
|
|
@@ -311,6 +309,73 @@ struct lambda_get_up_dim_num | |
| } | ||
| }; | ||
|
|
||
| // Named functors for tensor_descriptor transformations - reduce template instantiations | ||
| // | ||
| // Problem: Nested lambdas in transform_tensor_descriptor create multiple unique types | ||
| // - Inner lambda: [&](auto low_dim_visible_id) { return visible_ids.At(low_dim_visible_id); } | ||
| // - Outer lambda: [&](auto ids) { return transform_sequences(inner_lambda, ids); } | ||
| // - Each lambda capture combination creates unique type → many instantiations | ||
| // | ||
| // Solution: Named functor structs | ||
| // - Single reusable type per operation | ||
| // - No capture-dependent instantiations | ||
| // | ||
| // Impact: Significantly reduces tensor_descriptor template instantiation count | ||
| // | ||
| // convert_visible_to_hidden_id - maps single visible dimension ID to hidden ID | ||
| // | ||
| // Replaces: [&](auto low_dim_visible_id) { return old_visible_ids.At(low_dim_visible_id); } | ||
| // | ||
| // Note: transform_sequences passes index_t values, not Number<> types | ||
| // | ||
| template <typename OldTensorDescriptor> | ||
| struct convert_visible_to_hidden_id | ||
| { | ||
| __host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const | ||
| { | ||
| return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id); | ||
| } | ||
| }; | ||
|
|
||
| // convert_visible_ids_to_hidden_ids - maps sequence of visible IDs to hidden IDs | ||
| // | ||
| // Replaces: [&](auto low_dim_visible_ids) { return transform_sequences(convert_fn, ids); } | ||
| // | ||
| // Uses convert_visible_to_hidden_id functor to transform each element in the sequence | ||
| // | ||
| template <typename OldTensorDescriptor> | ||
| struct convert_visible_ids_to_hidden_ids | ||
| { | ||
| template <typename LowDimVisibleIds> | ||
| __host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const | ||
| { | ||
| return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{}, | ||
| low_dim_visible_ids); | ||
| } | ||
| }; | ||
|
|
||
| // generate_arithmetic_sequence_from_scan - generates arithmetic sequences for upper dimensions | ||
| // | ||
| // Replaces lambda: [&](auto i) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similarly: this "Replaces ..." comment is an artifact of the refactoring and should be removed. |
||
| // constexpr index_t start = old_hidden + scan.At(i); | ||
| // constexpr index_t end = old_hidden + scan.At(i+1); | ||
| // return arithmetic_sequence_gen<start, end, 1>{}; | ||
| // } | ||
| // | ||
| // Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions | ||
| // | ||
| template <index_t OldHiddenDimNumber, typename UpDimNumbersScan> | ||
| struct generate_arithmetic_sequence_from_scan | ||
| { | ||
| template <typename I> | ||
| __host__ __device__ constexpr auto operator()(I) const | ||
| { | ||
| constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{}); | ||
| constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{}); | ||
| return typename arithmetic_sequence_gen<start, end, 1>::type{}; | ||
| } | ||
| }; | ||
|
|
||
| template <typename OldTensorDescriptor, | ||
| typename NewTransforms, | ||
| typename NewLowerDimensionOldVisibleIdss, | ||
|
|
@@ -327,11 +392,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, | |
| NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), | ||
| "wrong! inconsitent number of transform"); | ||
|
|
||
| constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, | ||
| NewLowerDimensionOldVisibleIdss{}); | ||
| constexpr auto all_old_top_ids = | ||
| unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{}); | ||
|
|
||
| constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, | ||
| NewUpperDimensionNewVisibleIdss{}); | ||
| constexpr auto all_new_top_ids = | ||
| unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{}); | ||
|
|
||
| static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value && | ||
| is_valid_sequence_map<decltype(all_new_top_ids)>::value, | ||
|
|
@@ -341,17 +406,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, | |
| // lower dimension's hidden idss | ||
| // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of | ||
| // sequences) | ||
| constexpr auto low_dim_hidden_idss = transform_tuples( | ||
| // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) | ||
| [](auto low_dim_visible_ids) constexpr { | ||
| return transform_sequences( | ||
| // convert lower dimension visible id to hidden id | ||
| [](auto low_dim_visible_id) constexpr { | ||
| return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; | ||
| }, | ||
| low_dim_visible_ids); | ||
| }, | ||
| NewLowerDimensionOldVisibleIdss{}); | ||
| constexpr auto low_dim_hidden_idss = | ||
| transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{}, | ||
| NewLowerDimensionOldVisibleIdss{}); | ||
|
|
||
| constexpr index_t num_new_transform = NewTransforms::Size(); | ||
|
|
||
|
|
@@ -364,22 +421,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, | |
| constexpr auto up_dim_numbers_scan = merge_sequences( | ||
| Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{})); | ||
|
|
||
| using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>; | ||
| constexpr auto up_dim_hidden_idss = generate_tuple( | ||
| [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { | ||
| return | ||
| typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i], | ||
| old_hidden_dim_number + up_dim_numbers_scan[i + 1], | ||
| 1>::type{}; | ||
| }, | ||
| generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{}, | ||
| Number<num_new_transform>{}); | ||
|
|
||
| // new visible dimension's hidden ids | ||
| constexpr auto unordered_new_visible_dim_hidden_ids = | ||
| unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); | ||
| unpack_and_merge_sequences(up_dim_hidden_idss); | ||
|
|
||
| constexpr auto new_visible_dim_unordered2ordered = | ||
| unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, | ||
| NewUpperDimensionNewVisibleIdss{}); | ||
| unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{}); | ||
|
|
||
| constexpr auto new_visible_dim_hidden_ids = | ||
| unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -34,4 +34,43 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>) | |||||||||||||
| return Sequence<Is...>{}; | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
| // Named functors for sequence operations - optimized to reduce template instantiations | ||||||||||||||
| // | ||||||||||||||
| // Problem: Using lambdas with unpack creates unique types at each call site | ||||||||||||||
| // - Example: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple) | ||||||||||||||
| // - Each call site generates a new lambda type → multiple instantiations | ||||||||||||||
| // | ||||||||||||||
| // Solution: Named functor (merge_sequences_functor) | ||||||||||||||
| // - Single reusable type across all call sites | ||||||||||||||
| // - Eliminates per-call lambda instantiation overhead | ||||||||||||||
| // | ||||||||||||||
| // Impact: Significantly reduces template instantiations in tensor_descriptor operations | ||||||||||||||
| // | ||||||||||||||
| struct merge_sequences_functor | ||||||||||||||
| { | ||||||||||||||
| template <typename... Seqs> | ||||||||||||||
| __host__ __device__ constexpr auto operator()(Seqs... seqs) const | ||||||||||||||
| { | ||||||||||||||
| return merge_sequences(seqs...); | ||||||||||||||
| } | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| // unpack_and_merge_sequences - unpacks tuple of sequences and merges them | ||||||||||||||
| // | ||||||||||||||
| // Optimization: Uses named functor instead of lambda with unpack | ||||||||||||||
| // | ||||||||||||||
| // Why this approach: | ||||||||||||||
| // - Old: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple_of_sequences) | ||||||||||||||
| // Creates unique lambda type at each call site | ||||||||||||||
| // - New: unpack(merge_sequences_functor{}, tuple_of_sequences) | ||||||||||||||
| // Reuses single functor type across all call sites | ||||||||||||||
| // | ||||||||||||||
| // Use case: Common pattern in tensor_descriptor for merging dimension ID sequences | ||||||||||||||
| // | ||||||||||||||
| template <typename TupleOfSequences> | ||||||||||||||
| __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences) | ||||||||||||||
| { | ||||||||||||||
| return unpack(merge_sequences_functor{}, TupleOfSequences{}); | ||||||||||||||
|
Comment on lines
+71
to
+73
|
||||||||||||||
| __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences) | |
| { | |
| return unpack(merge_sequences_functor{}, TupleOfSequences{}); | |
| __host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences) | |
| { | |
| return unpack(merge_sequences_functor{}, tuple_of_sequences); |
Copilot
AI
Jan 22, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unpack_and_merge_sequences calls unpack(...), but this header only includes ck/utility/tuple.hpp and does not include the header that defines unpack (ck/utility/functional4.hpp). This breaks includes that pull in sequence_helper.hpp before functional4.hpp (e.g. include/ck/utility/container_helper.hpp includes sequence_helper.hpp before tuple_helper.hpp). Add the proper include here (or otherwise ensure unpack is declared).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we need to be careful to explicitly include the templates we use and not rely on transitive inclusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment about the lambda it replaces is obvious and not helpful going forward (since the lambda isn't around anymore).