82 changes: 62 additions & 20 deletions include/ck/utility/sequence.hpp
@@ -525,31 +525,73 @@ struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMa
 {
 };

-template <typename SeqMap>
-struct sequence_map_inverse
+// Invert a permutation sequence: given X2Y = {a, b, c, ...}, compute Y2X where Y2X[X2Y[i]] = i
+// Example: Sequence<2,0,1> (meaning pos0->2, pos1->0, pos2->1) inverts to Sequence<1,2,0>
+//
+// Why this implementation is faster to compile than recursive templates:
+//
+// The old recursive approach created a new template type for each element:
+// sequence_map_inverse<Seq<2,0,1>> -> sequence_map_inverse<Seq<0,1>> ->
+// sequence_map_inverse<Seq<1>>
+// Each "->" is a new type the compiler must create, track, and manage. For N elements, that's
+// N template types, each with overhead (name mangling, debug info, symbol table entries).
+//
+// This implementation uses a different strategy:
+// 1. Store the sequence values in a regular array (ConstexprArray)
+// 2. Use a normal for-loop (find_inverse) to search the array - runs at compile-time via constexpr
+// 3. Use "..." pack expansion to call find_inverse once per position in a single expression
+//
+// The key insight: a constexpr for-loop compiles to ONE template, while a recursive template
+// compiles to N templates. Both do N iterations of work, but the for-loop avoids creating
+// N separate types. This reduced compilation time by ~10% on large builds.
+namespace detail {
+// TODO: Replace with std::array when HIPRTC supports it
+// Simple array wrapper that works in constexpr context. Lets us convert the template parameter
+// pack (Is...) into an indexable array, so find_inverse() can loop over it.
+template <typename T, index_t N>
+struct ConstexprArray
 {
-    template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
-    struct sequence_map_inverse_impl
-    {
-        static constexpr auto new_y2x =
-            WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
+    T data[N];

-        using type =
-            typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
-                type;
-    };
+    constexpr const T& operator[](index_t i) const { return data[i]; }
+};
+} // namespace detail
+
+template <index_t... Is>
+struct sequence_map_inverse<Sequence<Is...>>
+{
+    private:
+    // Convert template parameters to array: Sequence<2,0,1> becomes values = {2,0,1}
+    static constexpr detail::ConstexprArray<index_t, sizeof...(Is)> values = {{Is...}};

-    template <typename X2Y, typename WorkingY2X, index_t XBegin>
-    struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
+    // Given a target value, find which position contains it.
+    // Example: values={2,0,1}, find_inverse(1) returns 2 because values[2]==1
+    // This is a regular for-loop, but runs at compile-time because it's constexpr.
Comment on lines +567 to +569

Copilot AI Jan 23, 2026

Corrected spelling of 'which' to 'witch' in the comment.

Contributor Author

interesting suggestion but nope

+    static constexpr index_t find_inverse(index_t target)
     {
-        using type = WorkingY2X;
-    };
+        for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
+        {
+            if(values[i] == target)
+                return i;
+        }
+        return -1; // should not reach for valid permutation

Copilot AI Jan 23, 2026

The return value -1 for an invalid permutation is misleading since index_t is likely unsigned. Consider using a static_assert or explicit error handling to catch invalid permutations at compile-time, or document that this path should never execute for valid inputs.

Suggested change
-        return -1; // should not reach for valid permutation
+        return static_cast<index_t>(-1); // should not reach for valid permutation

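One way to act on the static_assert idea in this thread is to validate the input once, before any lookup runs, so that the sentinel return is never observed for inputs that compile. The following is only a sketch under stated assumptions: `is_permutation_of_iota` and `is_valid_map` are hypothetical names that do not appear in this PR, `std::array` and `int` stand in for `ConstexprArray` and `index_t`, and nothing here describes how the library's existing `is_valid_sequence_map` is implemented.

```cpp
#include <array>
#include <cstddef>

// Hypothetical helper (not part of the PR): true when `values` contains
// each of 0..N-1 exactly once, i.e. when it is a permutation.
template <std::size_t N>
constexpr bool is_permutation_of_iota(const std::array<int, N>& values)
{
    for(std::size_t target = 0; target < N; ++target)
    {
        std::size_t hits = 0;
        for(std::size_t i = 0; i < N; ++i)
        {
            if(values[i] == static_cast<int>(target))
                ++hits;
        }
        if(hits != 1)
            return false; // target is missing or duplicated
    }
    return true;
}

// Hypothetical convenience wrapper over a parameter pack.
template <int... Is>
constexpr bool is_valid_map()
{
    return is_permutation_of_iota(std::array<int, sizeof...(Is)>{{Is...}});
}

static_assert(is_valid_map<2, 0, 1>(), "a valid permutation passes");
static_assert(!is_valid_map<2, 0, 2>(), "a duplicated value is rejected");
```

Under these assumptions, a `static_assert(is_valid_map<Is...>(), "...")` inside the specialization would reject a bad map with a compile error rather than letting `-1` appear in the result.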
+    }

-    using type =
-        typename sequence_map_inverse_impl<SeqMap,
-                                           typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
-                                           0,
-                                           SeqMap::Size()>::type;
+    // Why we need Positions... instead of just passing the size:
+    // The "..." syntax expands a parameter pack into repeated expressions. We need a pack
+    // to expand over. Sequence<0,1,2> gives us Positions = 0,1,2, which expands to:
+    // Sequence<find_inverse(0), find_inverse(1), find_inverse(2)>
+    // Without a pack, we'd need recursion to generate each element - defeating our goal.
+    template <index_t... Positions>
+    static constexpr auto compute(Sequence<Positions...>)
+    {
+        return Sequence<find_inverse(Positions)...>{};
+    }
+
+    public:
+    // make_index_sequence<N> generates Sequence<0,1,2,...,N-1>, giving us the pack to expand.
+    // Result: find_inverse called for each position 0..N-1, building the inverse sequence.
+    using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
 };

 template <index_t... Xs, index_t... Ys>
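The array, loop, and pack-expansion strategy described in the diff comments can be tried outside the library. The sketch below is a standalone approximation, not the PR's code: `Seq`, `seq_inverse`, and the `sketch` namespace are hypothetical names, `std::array` and `std::index_sequence` stand in for `ConstexprArray` and `make_index_sequence`, `int` stands in for `index_t`, and the helpers are free functions rather than private members. It assumes a C++14 or later compiler.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for ck::Sequence, used only in this sketch.
template <int... Is>
struct Seq
{
};

namespace sketch {

// Plain constexpr loop: position i with values[i] == target, or -1 if absent.
template <std::size_t N>
constexpr int find_inverse(const std::array<int, N>& values, int target)
{
    for(std::size_t i = 0; i < N; ++i)
    {
        if(values[i] == target)
            return static_cast<int>(i);
    }
    return -1; // not reached when the input is a permutation of 0..N-1
}

// A single pack expansion over Positions builds the result type,
// so no recursive chain of class templates is instantiated.
template <int... Is, std::size_t... Positions>
constexpr auto invert(Seq<Is...>, std::index_sequence<Positions...>)
{
    constexpr std::array<int, sizeof...(Is)> values = {{Is...}};
    return Seq<find_inverse(values, static_cast<int>(Positions))...>{};
}

} // namespace sketch

template <typename S>
struct seq_inverse;

template <int... Is>
struct seq_inverse<Seq<Is...>>
{
    using type =
        decltype(sketch::invert(Seq<Is...>{}, std::make_index_sequence<sizeof...(Is)>{}));
};

// Same example as the comment in the diff: <2,0,1> inverts to <1,2,0>.
static_assert(std::is_same<seq_inverse<Seq<2, 0, 1>>::type, Seq<1, 2, 0>>::value,
              "inverse of <2,0,1> should be <1,2,0>");
```

Each `find_inverse` call is an ordinary function evaluation, so in this sketch the only types created per inversion are the `seq_inverse` specialization, the `invert` instantiation, and the resulting `Seq`.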