diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048f..8c616ccee20 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -525,31 +525,73 @@ struct is_valid_sequence_map : is_same -struct sequence_map_inverse +// Invert a permutation sequence: given X2Y = {a, b, c, ...}, compute Y2X where Y2X[X2Y[i]] = i +// Example: Sequence<2,0,1> (meaning pos0->2, pos1->0, pos2->1) inverts to Sequence<1,2,0> +// +// Why this implementation is faster to compile than recursive templates: +// +// The old recursive approach created a new template type for each element: +// sequence_map_inverse> -> sequence_map_inverse> -> +// sequence_map_inverse> +// Each "->" is a new type the compiler must create, track, and manage. For N elements, that's +// N template types, each with overhead (name mangling, debug info, symbol table entries). +// +// This implementation uses a different strategy: +// 1. Store the sequence values in a regular array (ConstexprArray) +// 2. Use a normal for-loop (find_inverse) to search the array - runs at compile-time via constexpr +// 3. Use "..." pack expansion to call find_inverse once per position in a single expression +// +// The key insight: a constexpr for-loop compiles to ONE template, while a recursive template +// compiles to N templates. Both do N iterations of work, but the for-loop avoids creating +// N separate types. This reduced compilation time by ~10% on large builds. +namespace detail { +// TODO: Replace with std::array when HIPRTC supports it +// Simple array wrapper that works in constexpr context. Lets us convert the template parameter +// pack (Is...) into an indexable array, so find_inverse() can loop over it. +template +struct ConstexprArray { - template - struct sequence_map_inverse_impl - { - static constexpr auto new_y2x = - WorkingY2X::Modify(X2Y::At(Number{}), Number{}); + T data[N]; - using type = - typename sequence_map_inverse_impl:: - type; - }; + constexpr const T& operator[](index_t i) const { return data[i]; } +}; +} // namespace detail + +template +struct sequence_map_inverse> +{ + private: + // Convert template parameters to array: Sequence<2,0,1> becomes values = {2,0,1} + static constexpr detail::ConstexprArray values = {{Is...}}; - template - struct sequence_map_inverse_impl + // Given a target value, find which position contains it. + // Example: values={2,0,1}, find_inverse(1) returns 2 because values[2]==1 + // This is a regular for-loop, but runs at compile-time because it's constexpr. + static constexpr index_t find_inverse(index_t target) { - using type = WorkingY2X; - }; + for(index_t i = 0; i < static_cast(sizeof...(Is)); ++i) + { + if(values[i] == target) + return i; + } + return -1; // should not reach for valid permutation + } - using type = - typename sequence_map_inverse_impl::type, - 0, - SeqMap::Size()>::type; + // Why we need Positions... instead of just passing the size: + // The "..." syntax expands a parameter pack into repeated expressions. We need a pack + // to expand over. Sequence<0,1,2> gives us Positions = 0,1,2, which expands to: + // Sequence + // Without a pack, we'd need recursion to generate each element - defeating our goal. + template + static constexpr auto compute(Sequence) + { + return Sequence{}; + } + + public: + // make_index_sequence generates Sequence<0,1,2,...,N-1>, giving us the pack to expand. + // Result: find_inverse called for each position 0..N-1, building the inverse sequence. + using type = decltype(compute(make_index_sequence{})); }; template