@@ -6302,6 +6302,92 @@ LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op,
                            store_op.getMask());
 }
 
+namespace {
+
+// Structure to hold sublane count and derived values.
+struct Sublane {
+  const int count;
+  const int half;
+  const int quarter;
+  const int octa;
+  explicit Sublane(int sc)
+      : count(sc), half(sc / 2), quarter(sc / 4), octa(sc / 8) {}
+};
+
+// Structure to hold pairs of low and high pattern vectors.
+struct HighLowPatterns {
+  std::vector<int> low;
+  std::vector<int> high;
+  explicit HighLowPatterns(int size) : low(size), high(size) {}
+};
+
+// Helper to create combine patterns.
+HighLowPatterns CreateCombinePatterns(const Sublane& sublane) {
+  HighLowPatterns patterns(sublane.count);
+  // For sublane.count = 8, low = {0, 1, 2, 3, 8, 9, 10, 11}, high = {4, 5, 6,
+  // 7, 12, 13, 14, 15}.
+  for (int i = 0; i < sublane.half; ++i) {
+    patterns.low[i] = i;
+    patterns.low[i + sublane.half] = i + sublane.count;
+  }
+  absl::c_transform(patterns.low, patterns.high.begin(),
+                    [sublane](int value) { return value + sublane.half; });
+  return patterns;
+}
+
+// Helper to create shuffle patterns for Stage 0.
+HighLowPatterns CreateStage0ShufflePatterns(const Sublane& sublane) {
+  HighLowPatterns patterns(sublane.count);
+  for (int i = 0; i < sublane.quarter; ++i) {
+    patterns.low[i] = i;
+    patterns.low[i + sublane.quarter] = i + sublane.half;
+    patterns.low[i + sublane.half] = i + sublane.quarter;
+    patterns.low[i + sublane.half + sublane.quarter] =
+        i + sublane.half + sublane.quarter;
+  }
+  absl::c_transform(patterns.low, patterns.high.begin(),
+                    [sublane](int value) { return value + sublane.count; });
+  return patterns;
+}
+
+// Helper to create shuffle patterns for Stage 1.
+HighLowPatterns CreateStage1ShufflePatterns(const Sublane& sublane) {
+  HighLowPatterns patterns(sublane.count);
+  // For sublane.count = 8, low = {0, 1, 4, 5, 2, 3, 6, 7}, high = {8, 9, 12,
+  // 13, 10, 11, 14, 15}.
+  for (int i = 0; i < sublane.octa; ++i) {
+    patterns.low[4 * i] = 4 * i;
+    patterns.low[4 * i + 1] = 4 * i + 1;
+    patterns.low[4 * i + 2] = 4 * i + sublane.half;
+    patterns.low[4 * i + 3] = 4 * i + sublane.half + 1;
+    patterns.low[4 * i + sublane.half] = 4 * i + 2;
+    patterns.low[4 * i + sublane.half + 1] = 4 * i + 3;
+    patterns.low[4 * i + sublane.half + 2] = 4 * i + sublane.half + 2;
+    patterns.low[4 * i + sublane.half + 3] = 4 * i + sublane.half + 3;
+  }
+  absl::c_transform(patterns.low, patterns.high.begin(),
+                    [sublane](int value) { return value + sublane.count; });
+  return patterns;
+}
+
+// Helper to create shuffle patterns for Stage 2.
+HighLowPatterns CreateStage2ShufflePatterns(const Sublane& sublane) {
+  HighLowPatterns patterns(sublane.count);
+  // For sublane.count = 8, low = {0, 4, 2, 6, 1, 5, 3, 7}, high = {8, 12, 10,
+  // 14, 9, 13, 11, 15}.
+  for (int i = 0; i < sublane.quarter; ++i) {
+    patterns.low[2 * i] = 2 * i;
+    patterns.low[2 * i + 1] = 2 * i + sublane.half;
+    patterns.low[2 * i + sublane.half] = 2 * i + 1;
+    patterns.low[2 * i + sublane.half + 1] = 2 * i + sublane.half + 1;
+  }
+  absl::c_transform(patterns.low, patterns.high.begin(),
+                    [sublane](int value) { return value + sublane.count; });
+  return patterns;
+}
+
+}  // namespace
+
 LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
                                     const ArrayRef<Layout> layouts_in,
                                     const ArrayRef<Layout> layouts_out) {
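A quick way to sanity-check the new pattern generators is to run them standalone and compare against the values quoted in their comments. The sketch below was written for this review only (it is not part of the patch, and the helper names CombinePatterns/Stage2Patterns/Print are invented); it mirrors the index logic of CreateCombinePatterns and CreateStage2ShufflePatterns, substituting std::transform for absl::c_transform so it compiles without absl:

#include <algorithm>
#include <cstdio>
#include <vector>

struct Sublane {
  int count, half, quarter, octa;
  explicit Sublane(int sc)
      : count(sc), half(sc / 2), quarter(sc / 4), octa(sc / 8) {}
};

struct HighLowPatterns {
  std::vector<int> low, high;
  explicit HighLowPatterns(int size) : low(size), high(size) {}
};

// Same index logic as CreateCombinePatterns in the hunk above.
HighLowPatterns CombinePatterns(const Sublane& s) {
  HighLowPatterns p(s.count);
  for (int i = 0; i < s.half; ++i) {
    p.low[i] = i;
    p.low[i + s.half] = i + s.count;
  }
  std::transform(p.low.begin(), p.low.end(), p.high.begin(),
                 [&](int v) { return v + s.half; });
  return p;
}

// Same index logic as CreateStage2ShufflePatterns in the hunk above.
HighLowPatterns Stage2Patterns(const Sublane& s) {
  HighLowPatterns p(s.count);
  for (int i = 0; i < s.quarter; ++i) {
    p.low[2 * i] = 2 * i;
    p.low[2 * i + 1] = 2 * i + s.half;
    p.low[2 * i + s.half] = 2 * i + 1;
    p.low[2 * i + s.half + 1] = 2 * i + s.half + 1;
  }
  std::transform(p.low.begin(), p.low.end(), p.high.begin(),
                 [&](int v) { return v + s.count; });
  return p;
}

void Print(const char* name, const std::vector<int>& v) {
  std::printf("%s:", name);
  for (int x : v) std::printf(" %d", x);
  std::printf("\n");
}

int main() {
  const Sublane s8(8);
  const HighLowPatterns combine = CombinePatterns(s8);
  Print("combine.low ", combine.low);   // 0 1 2 3 8 9 10 11
  Print("combine.high", combine.high);  // 4 5 6 7 12 13 14 15
  const HighLowPatterns stage2 = Stage2Patterns(s8);
  Print("stage2.low  ", stage2.low);    // 0 4 2 6 1 5 3 7
  Print("stage2.high ", stage2.high);   // 8 12 10 14 9 13 11 15
}

The same check can be pointed at Sublane(16) to see how the patterns extend when the sublane count doubles.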
@@ -6351,14 +6437,17 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   // This is a 3 stage algorithm that uses combinations and shuffles to do a
   // transposition of an 8x8 or 4x8 or 2x8 block of sublanes. We use the
   // 32-bit 8x8 case as an example. For 4x8 and 2x8, we just need to run fewer
-  // rounds of the algorithm. For packed dtypes, we will essentially working
-  // on (8*packing)x8 blocks, and repeat the 8x8 algorithm `packing` times,
-  // the first 8x8 block contains the `packing * i`th input vregs, the second
-  // 8x8 block contains the `packing * i + 1`th input vregs, etc. Take 16-bit
-  // as an example, we have 16 vregs in total, they are [V0, V1,..., V15]. We
-  // view [V0, V2,...,V14] as the first 8x8 block, and [V1, V3,...,V15] as the
-  // second 8x8 block. We also need to do an extra unpacking and repacking
-  // step for packed dtypes after the 3 stage algorithm.
+  // rounds of the algorithm. The algorithm also generalizes to an arbitrary
+  // sublane count; for sublane counts larger than 8 we just need to run more
+  // rounds. For packed dtypes, we essentially work on (8*packing)x8 blocks
+  // and repeat the 8x8 algorithm `packing` times: the first 8x8 block
+  // contains the `packing * i`th input vregs, the second 8x8 block contains
+  // the `packing * i + 1`th input vregs, etc. Taking 16-bit as an example,
+  // we have 16 vregs in total, [V0, V1,..., V15]. We view [V0, V2,...,V14]
+  // as the first 8x8 block and [V1, V3,...,V15] as the second 8x8 block. We
+  // also need an extra unpacking and repacking step for packed dtypes after
+  // the 3-stage algorithm.
 
   // In the following algorithm description, A, B, ..., H represent 8 distinct
   // input vregs that form an 8x8 block of data to be transposed. In our
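To make the packed-dtype bookkeeping described in the hunk above concrete, here is a small standalone sketch (for this review only, not part of the patch; GroupVregsIntoBlocks is an invented illustrative helper) that groups input vreg indices into the (8*packing)x8 blocks the comment describes:

#include <cstdio>
#include <vector>

// Block b of the (8*packing)x8 decomposition holds every `packing`-th input
// vreg starting at offset b, i.e. vregs b, b + packing, b + 2*packing, ...
std::vector<std::vector<int>> GroupVregsIntoBlocks(int num_vregs, int packing) {
  std::vector<std::vector<int>> blocks(packing);
  for (int i = 0; i < num_vregs; ++i) {
    blocks[i % packing].push_back(i);
  }
  return blocks;
}

int main() {
  // The 16-bit example from the comment: packing = 2, 16 input vregs V0..V15.
  const auto blocks = GroupVregsIntoBlocks(/*num_vregs=*/16, /*packing=*/2);
  for (size_t b = 0; b < blocks.size(); ++b) {
    std::printf("block %zu:", b);
    for (int v : blocks[b]) std::printf(" V%d", v);
    std::printf("\n");
  }
  // Prints:
  //   block 0: V0 V2 V4 V6 V8 V10 V12 V14
  //   block 1: V1 V3 V5 V7 V9 V11 V13 V15
}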
@@ -6458,11 +6547,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   if (layout_in.offsets() != LayoutOffsets{0, 0}) {
     return transpose_op.emitOpError("Not implemented: Layout with offset.");
   }
-  // TODO(b/456173864): Relax this constraint.
-  if (ctx.target_shape[0] != 8) {
-    return transpose_op.emitOpError(
-        "Not implemented: Major-second-minor transpose expects 8 sublanes.");
-  }
 
   {
     // Transpose 4th+ minors if applicable.
@@ -6475,6 +6559,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
 
   const int packing = layout_in.packing();
   const int bitwidth = layout_in.bitwidth();
+  const Sublane sublane(ctx.target_shape[0]);
   const int64_t sublane_tiling = layout_in.tiling()[0];
 
   int64_t src_vregs_dim = src_vregs.num_dimensions();
@@ -6513,9 +6598,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
   // we cannot just resolve that in our outer loop is because of the nature
   // of a transpose - this dim value goes unmultiplied into the output vregs.
   // effectively, our indexing:
-  //   {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx,
+  //   {major_dim_slice_idx * sublane_tiling, second_minor_dim_slice_idx,
   //   minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx *
-  //   sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx}
+  //   sublane_tiling, major_dim_slice_idx, minor_most_dim_slice_idx}
   const int64_t major_dim_original_idx = permutation.size() - 3;
   const int64_t second_minor_dim_original_idx = permutation.size() - 2;
   const int64_t minor_most_dim_original_idx = permutation.size() - 1;
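As a concrete instance of the remapping in the comment above (numbers chosen only for illustration): with sublane_tiling = 8, major_dim_slice_idx = 1, second_minor_dim_slice_idx = 2 and minor_most_dim_slice_idx = 5, the source index {1 * 8, 2, 5} = {8, 2, 5} becomes {2 * 8, 1, 5} = {16, 1, 5} on the output side.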
@@ -6535,45 +6620,41 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
             .getResult();
   };
 
-  static constexpr std::array<int, 8> combine_low_pattern = {0, 1, 2, 3,
-                                                             8, 9, 10, 11};
-  static constexpr std::array<int, 8> combine_high_pattern = {4, 5, 6, 7,
-                                                              12, 13, 14, 15};
+  // Generate all patterns using helper functions.
+  const HighLowPatterns combine_patterns = CreateCombinePatterns(sublane);
+  const HighLowPatterns permute_patterns_stage0 =
+      CreateStage0ShufflePatterns(sublane);
+  const HighLowPatterns permute_patterns_stage1 =
+      CreateStage1ShufflePatterns(sublane);
+  const HighLowPatterns permute_patterns_stage2 =
+      CreateStage2ShufflePatterns(sublane);
 
   auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) {
-    return shuffle(lhs_vreg, rhs_vreg, combine_low_pattern);
+    return shuffle(lhs_vreg, rhs_vreg, combine_patterns.low);
   };
   auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) {
-    return shuffle(lhs_vreg, rhs_vreg, combine_high_pattern);
+    return shuffle(lhs_vreg, rhs_vreg, combine_patterns.high);
   };
 
-  // Shuffle patterns for Stage 1
-  // Input to shuffle: (combine_low_val, combine_high_val)
-  // combine_low_val has A0-A3, C0-C3. Indices 0-7 for shuffle.
-  // combine_high_val has A4-A7, C4-C7. Indices 8-15 for shuffle.
-  static constexpr std::array<int, 8> permute_pattern_stage1_low_arr = {
-      0, 1, 4, 5,
-      2, 3, 6, 7};  // Selects from combine_low_val to make A0A1C0C1A2A3C2C3
-  static constexpr std::array<int, 8> permute_pattern_stage1_high_arr = {
-      8, 9, 12, 13, 10,
-      11, 14, 15};  // Selects from combine_high_val to make A4A5C4C5A6A7C6C7
-
-  // Shuffle patterns for Stage 2
-  // Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments.
-  // CL_XY has A0A1C0C1B0B1D0D1. Indices 0-7 for shuffle.
-  // CH_XY has A2A3C2C3B2B3D2D3. Indices 8-15 for shuffle.
-  static constexpr std::array<int, 8> permute_pattern_stage2_low_arr = {
-      0, 4, 2, 6, 1, 5, 3, 7};  // Selects from CL_XY to make A0B0C0D0A1B1C1D1
-  static constexpr std::array<int, 8> permute_pattern_stage2_high_arr = {
-      8, 12, 10, 14,
-      9, 13, 11, 15};  // Selects from CH_XY to make A2B2C2D2A3B3C3D3
-
   llvm::SmallVector<int64_t, 4> original_dst_vregs_dims(
       dst_vregs.dimensions().begin(), dst_vregs.dimensions().end());
 
   reshape_to_4d(src_vregs);
   reshape_to_4d(dst_vregs);
 
+  // Prepare intermediate buffers needed for the algorithm.
+  std::vector<Value> stage0_output_vregs(sublane.count);
+  std::vector<Value> stage1_output_vregs(sublane.count);
+  std::vector<Value> stage2_output_vregs(sublane.count);
+  const int num_pairs_each_stage = sublane_tiling / (2 * packing);
+  const bool do_stage0 =
+      (sublane.count > 8) && (sublane_tiling / packing >= 8);
+  const bool do_stage1 = sublane_tiling / packing >= 4;
+  constexpr int stage0_stride = 4;
+  constexpr int stage1_stride = 2;
+  constexpr int stage2_stride = 1;
+  const int final_combine_stride = sublane_tiling / (2 * packing);
+
   // Iterate over the first dim. The algorithm operates on the last 3 dims.
   for (int outer_idx = 0; outer_idx < src_vregs.dim(0); ++outer_idx) {
     for (int minor_most_dim_slice_idx = 0;
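To see how the stride constants and num_pairs_each_stage drive the index pairing used by the stage loops in the next hunk, here is a small standalone sketch (for this review only, not part of the patch; PrintStagePairs is an invented illustrative helper). For the 32-bit, (8, 128)-tiling case, packing = 1, num_pairs_each_stage = 4, do_stage0 is false (sublane.count is only 8) and do_stage1 is true:

#include <cstdio>

// Mirrors the pairing used by the stage loops:
//   first  = (i / stride) * 2 * stride + (i % stride)
//   second = first + stride
void PrintStagePairs(const char* name, int num_pairs, int stride) {
  std::printf("%s:", name);
  for (int i = 0; i < num_pairs; ++i) {
    const int first = (i / stride) * 2 * stride + (i % stride);
    std::printf(" (%d,%d)", first, first + stride);
  }
  std::printf("\n");
}

int main() {
  PrintStagePairs("stage 1 (stride 2)", 4, 2);  // (0,2) (1,3) (4,6) (5,7)
  PrintStagePairs("stage 2 (stride 1)", 4, 1);  // (0,1) (2,3) (4,5) (6,7)
  // The final combine instead pairs i with i + final_combine_stride (= 4
  // here): (0,4) (1,5) (2,6) (3,7).
}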
@@ -6595,62 +6676,103 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
                         second_minor_dim_slice_idx,
                         minor_most_dim_slice_idx});
           }
+          // STAGE 0! Only needed when sublane_count > 8 and sublane_tiling
+          // is at least 8*packing.
+          if (do_stage0) {
+            for (int i = 0; i < num_pairs_each_stage; ++i) {
+              int stage0_first_idx =
+                  (i / stage0_stride) * 2 * stage0_stride +
+                  (i % stage0_stride);
+              int stage0_second_idx = stage0_first_idx + stage0_stride;
+
+              Value stage0_first_vreg = src_vregs(
+                  {outer_idx,
+                   stage0_first_idx * packing + major_subidx +
+                       sublane_tiling * major_dim_slice_idx,
+                   second_minor_dim_slice_idx, minor_most_dim_slice_idx});
+              Value stage0_second_vreg = src_vregs(
+                  {outer_idx,
+                   stage0_second_idx * packing + major_subidx +
+                       sublane_tiling * major_dim_slice_idx,
+                   second_minor_dim_slice_idx, minor_most_dim_slice_idx});
+
+              auto combined_low_val =
+                  combine_low(stage0_first_vreg, stage0_second_vreg);
+              auto combined_high_val =
+                  combine_high(stage0_first_vreg, stage0_second_vreg);
+
+              stage0_output_vregs[stage0_first_idx] =
+                  shuffle(combined_low_val, combined_high_val,
+                          permute_patterns_stage0.low);
+              stage0_output_vregs[stage0_second_idx] =
+                  shuffle(combined_low_val, combined_high_val,
+                          permute_patterns_stage0.high);
+            }
+          }
 
-          // STAGE 1!
-          std::array<Value, 8>
-              stage1_output_vregs;  // Stores s1_vregs from comments
-          const int num_pairs_stage1 =
-              sublane_tiling / (2 * packing);  // Processes pairs of vregs
-                                               // (A,C), (B,D), (E,G), (F,H)
-          const int stage1_stride = num_pairs_stage1 > 1 ? 2 : 1;
-
-          for (int i = 0; i < num_pairs_stage1; ++i) {
-            int stage1_first_idx = (i / 2) * 4 + (i % 2);
-            int stage1_second_idx = stage1_first_idx + stage1_stride;
-
-            Value first_vreg = src_vregs(
-                {outer_idx,
-                 stage1_first_idx * packing + major_subidx +
-                     sublane_tiling * major_dim_slice_idx,
-                 second_minor_dim_slice_idx, minor_most_dim_slice_idx});
-            Value second_vreg = src_vregs(
-                {outer_idx,
-                 stage1_second_idx * packing + major_subidx +
-                     sublane_tiling * major_dim_slice_idx,
-                 second_minor_dim_slice_idx, minor_most_dim_slice_idx});
-
-            auto combined_low_val = combine_low(first_vreg, second_vreg);
-            auto combined_high_val = combine_high(first_vreg, second_vreg);
-
-            // Initialize for (2*packing, 128) tiling and combine for larger
-            // tilings.
-            stage1_output_vregs[stage1_first_idx] =
-                (sublane_tiling / packing == 2)
-                    ? first_vreg
-                    : shuffle(combined_low_val, combined_high_val,
-                              permute_pattern_stage1_low_arr);
-            stage1_output_vregs[stage1_second_idx] =
-                (sublane_tiling / packing == 2)
-                    ? second_vreg
-                    : shuffle(combined_low_val, combined_high_val,
-                              permute_pattern_stage1_high_arr);
+          // STAGE 1! Only needed when sublane_tiling is at least 4*packing.
+          if (do_stage1) {
+            for (int i = 0; i < num_pairs_each_stage; ++i) {
+              int stage1_first_idx =
+                  (i / stage1_stride) * 2 * stage1_stride +
+                  (i % stage1_stride);
+              int stage1_second_idx = stage1_first_idx + stage1_stride;
+
+              Value stage1_first_vreg =
+                  do_stage0
+                      ? stage0_output_vregs[stage1_first_idx]
+                      : src_vregs({outer_idx,
+                                   stage1_first_idx * packing +
+                                       major_subidx +
+                                       sublane_tiling * major_dim_slice_idx,
+                                   second_minor_dim_slice_idx,
+                                   minor_most_dim_slice_idx});
+              Value stage1_second_vreg =
+                  do_stage0
+                      ? stage0_output_vregs[stage1_second_idx]
+                      : src_vregs({outer_idx,
+                                   stage1_second_idx * packing +
+                                       major_subidx +
+                                       sublane_tiling * major_dim_slice_idx,
+                                   second_minor_dim_slice_idx,
+                                   minor_most_dim_slice_idx});
+
+              auto combined_low_val =
+                  combine_low(stage1_first_vreg, stage1_second_vreg);
+              auto combined_high_val =
+                  combine_high(stage1_first_vreg, stage1_second_vreg);
+
+              stage1_output_vregs[stage1_first_idx] =
+                  shuffle(combined_low_val, combined_high_val,
+                          permute_patterns_stage1.low);
+              stage1_output_vregs[stage1_second_idx] =
+                  shuffle(combined_low_val, combined_high_val,
+                          permute_patterns_stage1.high);
+            }
           }
 
           // STAGE 2!
-          std::array<Value, 8>
-              stage2_output_vregs;  // Stores s2_vregs from comments
-          const int num_pairs_stage2 =
-              sublane_tiling / (2 * packing);  // Processes pairs of vregs
-                                               // from stage1_output_vregs
-          constexpr int stage2_stride = 1;
-
-          for (int i = 0; i < num_pairs_stage2; ++i) {
+          for (int i = 0; i < num_pairs_each_stage; ++i) {
             int stage2_first_idx = 2 * i;
             int stage2_second_idx = stage2_first_idx + stage2_stride;
 
-            Value stage2_first_vreg = stage1_output_vregs[stage2_first_idx];
+            Value stage2_first_vreg =
+                do_stage1
+                    ? stage1_output_vregs[stage2_first_idx]
+                    : src_vregs({outer_idx,
+                                 stage2_first_idx * packing + major_subidx +
+                                     sublane_tiling * major_dim_slice_idx,
+                                 second_minor_dim_slice_idx,
+                                 minor_most_dim_slice_idx});
             Value stage2_second_vreg =
-                stage1_output_vregs[stage2_second_idx];
+                do_stage1
+                    ? stage1_output_vregs[stage2_second_idx]
+                    : src_vregs({outer_idx,
+                                 stage2_second_idx * packing +
+                                     major_subidx +
+                                     sublane_tiling * major_dim_slice_idx,
+                                 second_minor_dim_slice_idx,
+                                 minor_most_dim_slice_idx});
 
             auto combined_low_val =
                 combine_low(stage2_first_vreg, stage2_second_vreg);
@@ -6659,10 +6781,10 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
 
             stage2_output_vregs[stage2_first_idx] =
                 shuffle(combined_low_val, combined_high_val,
-                        permute_pattern_stage2_low_arr);
+                        permute_patterns_stage2.low);
             stage2_output_vregs[stage2_second_idx] =
                 shuffle(combined_low_val, combined_high_val,
-                        permute_pattern_stage2_high_arr);
+                        permute_patterns_stage2.high);
           }
 
           // STAGE 3! Combine results from stage 2.
@@ -6672,9 +6794,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
           // s2_vregs[0]..s2_vregs[1] pairing with s2_vregs[2]..s2_vregs[3].
           // For (2, 128) tiling, this corresponds to s2_vregs[0] pairing
           // with s2_vregs[1].
-          const int num_final_combines = sublane_tiling / (2 * packing);
-          const int final_combine_stride = sublane_tiling / (2 * packing);
-          for (int i = 0; i < num_final_combines; ++i) {
+          for (int i = 0; i < num_pairs_each_stage; ++i) {
             Value lhs = stage2_output_vregs[i];
             Value rhs = stage2_output_vregs[i + final_combine_stride];
             auto final_combined_low = combine_low(lhs, rhs);
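For completeness, in the 32-bit (8, 128)-tiling case (not spelled out in the comment above) num_pairs_each_stage and final_combine_stride are both 4, so this loop pairs s2_vregs[0] with s2_vregs[4], s2_vregs[1] with s2_vregs[5], s2_vregs[2] with s2_vregs[6], and s2_vregs[3] with s2_vregs[7].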