Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 80 additions & 28 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,9 @@ struct TensorDescriptor

__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});

constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});

constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

Expand Down Expand Up @@ -311,6 +309,73 @@ struct lambda_get_up_dim_num
}
};

// Named functors for tensor_descriptor transformations - reduce template instantiations
//
// Problem: Nested lambdas in transform_tensor_descriptor create multiple unique types
// - Inner lambda: [&](auto low_dim_visible_id) { return visible_ids.At(low_dim_visible_id); }
// - Outer lambda: [&](auto ids) { return transform_sequences(inner_lambda, ids); }
// - Each lambda capture combination creates unique type → many instantiations
//
// Solution: Named functor structs
// - Single reusable type per operation
// - No capture-dependent instantiations
//
// Impact: Significantly reduces tensor_descriptor template instantiation count
//
// convert_visible_to_hidden_id - maps a single visible dimension ID to its hidden ID
//
// Named functor (rather than a capturing lambda) so one reusable type is shared
// across all call sites, reducing template instantiations.
//
// Note: transform_sequences passes index_t values, not Number<> types
//
template <typename OldTensorDescriptor>
struct convert_visible_to_hidden_id
{
    // Look up the hidden dimension ID at position low_dim_visible_id in the
    // old descriptor's visible dimension ID sequence.
    __host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
    {
        return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
    }
};

// convert_visible_ids_to_hidden_ids - maps sequence of visible IDs to hidden IDs
//
// Replaces: [&](auto low_dim_visible_ids) { return transform_sequences(convert_fn, ids); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment about the lambda it replaces is obvious and not helpful going forward (since the lambda isn't around anymore).

//
// Uses convert_visible_to_hidden_id functor to transform each element in the sequence
//
template <typename OldTensorDescriptor>
struct convert_visible_ids_to_hidden_ids
{
template <typename LowDimVisibleIds>
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
{
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
low_dim_visible_ids);
}
};

// generate_arithmetic_sequence_from_scan - generates arithmetic sequences for upper dimensions
//
// Replaces lambda: [&](auto i) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly: this "Replaces ..." comment is an artifact of the refactoring and should be removed.

// constexpr index_t start = old_hidden + scan.At(i);
// constexpr index_t end = old_hidden + scan.At(i+1);
// return arithmetic_sequence_gen<start, end, 1>{};
// }
//
// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions
//
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
struct generate_arithmetic_sequence_from_scan
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
return typename arithmetic_sequence_gen<start, end, 1>::type{};
}
};

template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
Expand All @@ -327,11 +392,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");

constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_old_top_ids =
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});

constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto all_new_top_ids =
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
Expand All @@ -341,17 +406,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr auto low_dim_hidden_idss =
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
NewLowerDimensionOldVisibleIdss{});

constexpr index_t num_new_transform = NewTransforms::Size();

Expand All @@ -364,22 +421,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));

using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
constexpr auto up_dim_hidden_idss = generate_tuple(
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
Number<num_new_transform>{});

// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
unpack_and_merge_sequences(up_dim_hidden_idss);

constexpr auto new_visible_dim_unordered2ordered =
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});

constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
Expand Down
7 changes: 2 additions & 5 deletions include/ck/tensor_operation/gpu/device/matrix_padder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
},
Number<num_dim>{});

// lower dimension Id
const auto lower_dimss =
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});

// upper dimension Id
// lower/upper dimension Ids
const auto lower_dimss = generate_identity_sequences<num_dim>();
const auto upper_dimss = lower_dimss;

return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -866,8 +866,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -925,8 +924,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -706,8 +705,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down Expand Up @@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
},
Number<nDim>{});

constexpr auto up_dim_idss =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();

return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
Expand Down
39 changes: 39 additions & 0 deletions include/ck/utility/sequence_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,43 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}

// Named functors for sequence operations - optimized to reduce template instantiations
//
// Problem: Using lambdas with unpack creates unique types at each call site
// - Example: unpack([](auto... xs) { return merge_sequences(xs...); }, tuple)
// - Each call site generates a new lambda type → multiple instantiations
//
// Solution: Named functor (merge_sequences_functor)
// - Single reusable type across all call sites
// - Eliminates per-call lambda instantiation overhead
//
// Impact: Significantly reduces template instantiations in tensor_descriptor operations
//
struct merge_sequences_functor
{
    // Forward the unpacked sequences straight into merge_sequences, yielding a
    // single merged Sequence.
    template <typename... Sequences>
    __host__ __device__ constexpr auto operator()(Sequences... sequences) const
    {
        return merge_sequences(sequences...);
    }
};

// unpack_and_merge_sequences - unpacks a tuple of sequences and merges them into
// one sequence
//
// Uses the named merge_sequences_functor instead of a per-call-site lambda so a
// single functor type is reused everywhere, reducing template instantiations.
//
// NOTE(review): this calls unpack(), which is defined elsewhere (functional4.hpp);
// ensure that header is included explicitly here rather than relied on transitively.
//
template <typename TupleOfSequences>
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
{
    // Use the passed-in tuple value rather than default-constructing
    // TupleOfSequences{}, so non-default-constructible tuple-like types work too.
    return unpack(merge_sequences_functor{}, tuple_of_sequences);
}
Comment on lines +70 to +74
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unpack_and_merge_sequences calls unpack(...), but this header only includes ck/utility/tuple.hpp and does not include the header that defines unpack (ck/utility/functional4.hpp). This breaks includes that pull in sequence_helper.hpp before functional4.hpp (e.g. include/ck/utility/container_helper.hpp includes sequence_helper.hpp before tuple_helper.hpp). Add the proper include here (or otherwise ensure unpack is declared).

Copilot uses AI. Check for mistakes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need to be careful to explicitly include the templates we use and not rely on transitive inclusion.


} // namespace ck
Loading