Skip to content

Commit 2611fd2

Browse files
yueshengys authored and
Google-ML-Automation committed
[Mosaic TPU] Support second/third minor transpose when sublane_count is not 8.
PiperOrigin-RevId: 833923508
1 parent 2912a1c commit 2611fd2

File tree

1 file changed

+215
-95
lines changed

1 file changed

+215
-95
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 215 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6302,6 +6302,92 @@ LogicalResult tpu_vector_store_rule(RewriteContext &ctx, Operation &op,
63026302
store_op.getMask());
63036303
}
63046304

6305+
namespace {
6306+
6307+
// Structure to hold sublane count and derived values.
//
// The derived fields are the fractions of the count that the pattern helpers
// in this namespace use as loop bounds and index offsets. They are computed
// with integer division, so they truncate whenever `sc` is not a multiple of
// the divisor (for example, `octa` is 0 for any `sc` below 8). All members
// are const, so an instance is fixed once constructed.
6308+
struct Sublane {
6309+
const int count;  // The sublane count as passed to the constructor.
6310+
const int half;  // count / 2 (truncating).
6311+
const int quarter;  // count / 4 (truncating).
6312+
const int octa;  // count / 8 (truncating).
6313+
explicit Sublane(int sc)
6314+
: count(sc), half(sc / 2), quarter(sc / 4), octa(sc / 8) {}
6315+
};
6316+
6317+
// Structure to hold pairs of low and high pattern vectors.
//
// The constructor sizes both vectors to `size` elements, value-initialized to
// 0; the Create*Patterns helpers in this namespace then overwrite the entries.
6318+
struct HighLowPatterns {
6319+
std::vector<int> low;  // Index pattern for the "low" result.
6320+
std::vector<int> high;  // Index pattern for the "high" result.
6321+
explicit HighLowPatterns(int size) : low(size), high(size) {}
6322+
};
6323+
6324+
// Helper to create combine patterns.
//
// Returns a low/high pair of index vectors, each of length `sublane.count`.
//   - low:  entries [0, half) are 0 .. half-1, and entries [half, 2*half) are
//           count .. count+half-1.
//   - high: every entry of `low` shifted up by `half`.
// NOTE(review): presumably indices in [0, count) select from the first
// shuffle operand and indices in [count, 2*count) from the second — confirm
// against the `shuffle` helper used by vector_transpose_rule.
6325+
HighLowPatterns CreateCombinePatterns(const Sublane& sublane) {
6326+
HighLowPatterns patterns(sublane.count);
6327+
// For sublane.count = 8, low = {0, 1, 2, 3, 8, 9, 10, 11}, high = {4, 5, 6,
6328+
// 7, 12, 13, 14, 15}.
6329+
for (int i = 0; i < sublane.half; ++i) {
6330+
patterns.low[i] = i;
6331+
patterns.low[i + sublane.half] = i + sublane.count;
6332+
}
// Derive `high` from `low` element-wise; the lambda holds its own copy of
// `sublane`.
6333+
absl::c_transform(patterns.low, patterns.high.begin(),
6334+
[sublane](int value) { return value + sublane.half; });
6335+
return patterns;
6336+
}
6337+
6338+
// Helper to create shuffle patterns for Stage 0.
//
// Returns a low/high pair of index vectors, each of length `sublane.count`.
// `low` is the identity permutation with its second and third quarters
// swapped; `high` is `low` with every entry shifted up by `count`.
// For sublane.count = 16, low = {0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12,
// 13, 14, 15} and high = {16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28,
// 29, 30, 31}.
6339+
HighLowPatterns CreateStage0ShufflePatterns(const Sublane& sublane) {
6340+
HighLowPatterns patterns(sublane.count);
6341+
for (int i = 0; i < sublane.quarter; ++i) {
6342+
patterns.low[i] = i;  // First quarter stays in place.
6343+
patterns.low[i + sublane.quarter] = i + sublane.half;  // Takes 3rd quarter.
6344+
patterns.low[i + sublane.half] = i + sublane.quarter;  // Takes 2nd quarter.
6345+
patterns.low[i + sublane.half + sublane.quarter] =
6346+
i + sublane.half + sublane.quarter;  // Last quarter stays in place.
6347+
}
// Derive `high` from `low` element-wise; the lambda holds its own copy of
// `sublane`.
6348+
absl::c_transform(patterns.low, patterns.high.begin(),
6349+
[sublane](int value) { return value + sublane.count; });
6350+
return patterns;
6351+
}
6352+
6353+
// Helper to create shuffle patterns for Stage 1.
//
// Returns a low/high pair of index vectors, each of length `sublane.count`.
// Each loop iteration fills one group of four entries in the first half of
// `low` (at offset 4*i) and one group of four in the second half (at offset
// 4*i + half). `high` is `low` with every entry shifted up by `count`.
// NOTE: the loop bound is `sublane.octa` (count / 8), so for a count below 8
// the loop never runs and `low` keeps its initial zeros (and `high` becomes
// all `count`). The loop writes 8 * octa entries in total, which covers the
// whole vector only when the count is a multiple of 8.
6354+
HighLowPatterns CreateStage1ShufflePatterns(const Sublane& sublane) {
6355+
HighLowPatterns patterns(sublane.count);
6356+
// For sublane.count = 8, low = {0, 1, 4, 5, 2, 3, 6, 7}, high = {8, 9, 12,
6357+
// 13, 10, 11, 14, 15}.
6358+
for (int i = 0; i < sublane.octa; ++i) {
6359+
patterns.low[4 * i] = 4 * i;
6360+
patterns.low[4 * i + 1] = 4 * i + 1;
6361+
patterns.low[4 * i + 2] = 4 * i + sublane.half;
6362+
patterns.low[4 * i + 3] = 4 * i + sublane.half + 1;
6363+
patterns.low[4 * i + sublane.half] = 4 * i + 2;
6364+
patterns.low[4 * i + sublane.half + 1] = 4 * i + 3;
6365+
patterns.low[4 * i + sublane.half + 2] = 4 * i + sublane.half + 2;
6366+
patterns.low[4 * i + sublane.half + 3] = 4 * i + sublane.half + 3;
6367+
}
// Derive `high` from `low` element-wise; the lambda holds its own copy of
// `sublane`.
6368+
absl::c_transform(patterns.low, patterns.high.begin(),
6369+
[sublane](int value) { return value + sublane.count; });
6370+
return patterns;
6371+
}
6372+
6373+
// Helper to create shuffle patterns for Stage 2.
//
// Returns a low/high pair of index vectors, each of length `sublane.count`.
// Each loop iteration fills one pair of entries in the first half of `low`
// (at offset 2*i) and one pair in the second half (at offset 2*i + half).
// `high` is `low` with every entry shifted up by `count`.
// NOTE: the loop bound is `sublane.quarter` (count / 4), so 4 * quarter
// entries are written; this covers the whole vector only when the count is a
// multiple of 4. Entries not written keep their initial value of 0.
6374+
HighLowPatterns CreateStage2ShufflePatterns(const Sublane& sublane) {
6375+
HighLowPatterns patterns(sublane.count);
6376+
// For sublane.count = 8, low = {0, 4, 2, 6, 1, 5, 3, 7}, high = {8, 12, 10,
6377+
// 14, 9, 13, 11, 15}.
6378+
for (int i = 0; i < sublane.quarter; ++i) {
6379+
patterns.low[2 * i] = 2 * i;
6380+
patterns.low[2 * i + 1] = 2 * i + sublane.half;
6381+
patterns.low[2 * i + sublane.half] = 2 * i + 1;
6382+
patterns.low[2 * i + sublane.half + 1] = 2 * i + sublane.half + 1;
6383+
}
// Derive `high` from `low` element-wise; the lambda holds its own copy of
// `sublane`.
6384+
absl::c_transform(patterns.low, patterns.high.begin(),
6385+
[sublane](int value) { return value + sublane.count; });
6386+
return patterns;
6387+
}
6388+
6389+
} // namespace
6390+
63056391
LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
63066392
const ArrayRef<Layout> layouts_in,
63076393
const ArrayRef<Layout> layouts_out) {
@@ -6351,14 +6437,17 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
63516437
// This is a 3 stage algorithm that uses combinations and shuffles to do a
63526438
// transposition of an 8x8 or 4x8 or 2x8 block of sublanes. We use the
63536439
// 32-bit 8x8 case as an example. For 4x8 and 2x8, we just need to run fewer
6354-
// rounds of the algorithm. For packed dtypes, we will essentially working
6355-
// on (8*packing)x8 blocks, and repeat the 8x8 algorithm `packing` times,
6356-
// the first 8x8 block contains the `packing * i`th input vregs, the second
6357-
// 8x8 block contains the `packing * i + 1`th input vregs, etc. Take 16-bit
6358-
// as an example, we have 16 vregs in total, they are [V0, V1,..., V15]. We
6359-
// view [V0, V2,...,V14] as the first 8x8 block, and [V1, V3,...,V15] as the
6360-
// second 8x8 block. We also need to do an extra unpacking and repacking
6361-
// step for packed dtypes after the 3 stage algorithm.
6440+
// rounds of the algorithm. This algorithm can also be generalized to
6441+
// support arbitrary sublane counts. For sublane counts larger than 8, we just
6442+
// need to run more rounds. For packed dtypes, we will essentially be
6443+
// working on (8*packing)x8 blocks, and repeat the 8x8 algorithm `packing`
6444+
// times, the first 8x8 block contains the `packing * i`th input vregs, the
6445+
// second 8x8 block contains the `packing * i + 1`th input vregs, etc. Take
6446+
// 16-bit as an example, we have 16 vregs in total, they are [V0, V1,...,
6447+
// V15]. We view [V0, V2,...,V14] as the first 8x8 block, and [V1,
6448+
// V3,...,V15] as the second 8x8 block. We also need to do an extra
6449+
// unpacking and repacking step for packed dtypes after the 3 stage
6450+
// algorithm.
63626451

63636452
// In the following algorithm description, A, B, ..., H represent 8 distinct
63646453
// input vregs that form an 8x8 block of data to be transposed. In our
@@ -6458,11 +6547,6 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
64586547
if (layout_in.offsets() != LayoutOffsets{0, 0}) {
64596548
return transpose_op.emitOpError("Not implemented: Layout with offset.");
64606549
}
6461-
// TODO(b/456173864): Relax this constraint.
6462-
if (ctx.target_shape[0] != 8) {
6463-
return transpose_op.emitOpError(
6464-
"Not implemented: Major-second-minor transpose expects 8 sublanes.");
6465-
}
64666550

64676551
{
64686552
// Transpose 4th+ minors if applicable.
@@ -6475,6 +6559,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
64756559

64766560
const int packing = layout_in.packing();
64776561
const int bitwidth = layout_in.bitwidth();
6562+
const Sublane sublane(ctx.target_shape[0]);
64786563
const int64_t sublane_tiling = layout_in.tiling()[0];
64796564

64806565
int64_t src_vregs_dim = src_vregs.num_dimensions();
@@ -6513,9 +6598,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
65136598
// we cannot just resolve that in our outer loop is because of the nature
65146599
// of a transpose - this dim value goes unmultiplied into the output vregs.
65156600
// effectively, our indexing:
6516-
// {major_dim_slice_idx * sublane_count, second_minor_dim_slice_idx,
6601+
// {major_dim_slice_idx * sublane_tiling, second_minor_dim_slice_idx,
65176602
// minor_most_dim_slice_idx} becomes {second_minor_dim_slice_idx *
6518-
// sublane_count, major_dim_slice_idx, minor_most_dim_slice_idx}
6603+
// sublane_tiling, major_dim_slice_idx, minor_most_dim_slice_idx}
65196604
const int64_t major_dim_original_idx = permutation.size() - 3;
65206605
const int64_t second_minor_dim_original_idx = permutation.size() - 2;
65216606
const int64_t minor_most_dim_original_idx = permutation.size() - 1;
@@ -6535,45 +6620,41 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
65356620
.getResult();
65366621
};
65376622

6538-
static constexpr std::array<int, 8> combine_low_pattern = {0, 1, 2, 3,
6539-
8, 9, 10, 11};
6540-
static constexpr std::array<int, 8> combine_high_pattern = {4, 5, 6, 7,
6541-
12, 13, 14, 15};
6623+
// Generate all patterns using helper functions.
6624+
const HighLowPatterns combine_patterns = CreateCombinePatterns(sublane);
6625+
const HighLowPatterns permute_patterns_stage0 =
6626+
CreateStage0ShufflePatterns(sublane);
6627+
const HighLowPatterns permute_patterns_stage1 =
6628+
CreateStage1ShufflePatterns(sublane);
6629+
const HighLowPatterns permute_patterns_stage2 =
6630+
CreateStage2ShufflePatterns(sublane);
65426631

65436632
auto combine_low = [&](Value lhs_vreg, Value rhs_vreg) {
6544-
return shuffle(lhs_vreg, rhs_vreg, combine_low_pattern);
6633+
return shuffle(lhs_vreg, rhs_vreg, combine_patterns.low);
65456634
};
65466635
auto combine_high = [&](Value lhs_vreg, Value rhs_vreg) {
6547-
return shuffle(lhs_vreg, rhs_vreg, combine_high_pattern);
6636+
return shuffle(lhs_vreg, rhs_vreg, combine_patterns.high);
65486637
};
65496638

6550-
// Shuffle patterns for Stage 1
6551-
// Input to shuffle: (combine_low_val, combine_high_val)
6552-
// combine_low_val has A0-A3, C0-C3. Indices 0-7 for shuffle.
6553-
// combine_high_val has A4-A7, C4-C7. Indices 8-15 for shuffle.
6554-
static constexpr std::array<int, 8> permute_pattern_stage1_low_arr = {
6555-
0, 1, 4, 5,
6556-
2, 3, 6, 7}; // Selects from combine_low_val to make A0A1C0C1A2A3C2C3
6557-
static constexpr std::array<int, 8> permute_pattern_stage1_high_arr = {
6558-
8, 9, 12, 13, 10,
6559-
11, 14, 15}; // Selects from combine_high_val to make A4A5C4C5A6A7C6C7
6560-
6561-
// Shuffle patterns for Stage 2
6562-
// Input to shuffle: (CL_XY, CH_XY) from Step 2.1 in comments.
6563-
// CL_XY has A0A1C0C1B0B1D0D1. Indices 0-7 for shuffle.
6564-
// CH_XY has A2A3C2C3B2B3D2D3. Indices 8-15 for shuffle.
6565-
static constexpr std::array<int, 8> permute_pattern_stage2_low_arr = {
6566-
0, 4, 2, 6, 1, 5, 3, 7}; // Selects from CL_XY to make A0B0C0D0A1B1C1D1
6567-
static constexpr std::array<int, 8> permute_pattern_stage2_high_arr = {
6568-
8, 12, 10, 14,
6569-
9, 13, 11, 15}; // Selects from CH_XY to make A2B2C2D2A3B3C3D3
6570-
65716639
llvm::SmallVector<int64_t, 4> original_dst_vregs_dims(
65726640
dst_vregs.dimensions().begin(), dst_vregs.dimensions().end());
65736641

65746642
reshape_to_4d(src_vregs);
65756643
reshape_to_4d(dst_vregs);
65766644

6645+
// Prepare intermediate buffers needed for the algorithm.
6646+
std::vector<Value> stage0_output_vregs(sublane.count);
6647+
std::vector<Value> stage1_output_vregs(sublane.count);
6648+
std::vector<Value> stage2_output_vregs(sublane.count);
6649+
const int num_pairs_each_stage = sublane_tiling / (2 * packing);
6650+
const bool do_stage0 =
6651+
(sublane.count > 8) && (sublane_tiling / packing >= 8);
6652+
const bool do_stage1 = sublane_tiling / packing >= 4;
6653+
constexpr int stage0_stride = 4;
6654+
constexpr int stage1_stride = 2;
6655+
constexpr int stage2_stride = 1;
6656+
const int final_combine_stride = sublane_tiling / (2 * packing);
6657+
65776658
// Iterate over the first dim. The algorithm operates on the last 3 dims.
65786659
for (int outer_idx = 0; outer_idx < src_vregs.dim(0); ++outer_idx) {
65796660
for (int minor_most_dim_slice_idx = 0;
@@ -6595,62 +6676,103 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
65956676
second_minor_dim_slice_idx,
65966677
minor_most_dim_slice_idx});
65976678
}
6679+
// STAGE 0! Only needed when sublane_count > 8 and sublane_tiling
6680+
// is at least 8*packing.
6681+
if (do_stage0) {
6682+
for (int i = 0; i < num_pairs_each_stage; ++i) {
6683+
int stage0_first_idx =
6684+
(i / stage0_stride) * 2 * stage0_stride +
6685+
(i % stage0_stride);
6686+
int stage0_second_idx = stage0_first_idx + stage0_stride;
6687+
6688+
Value stage0_first_vreg = src_vregs(
6689+
{outer_idx,
6690+
stage0_first_idx * packing + major_subidx +
6691+
sublane_tiling * major_dim_slice_idx,
6692+
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
6693+
Value stage0_second_vreg = src_vregs(
6694+
{outer_idx,
6695+
stage0_second_idx * packing + major_subidx +
6696+
sublane_tiling * major_dim_slice_idx,
6697+
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
6698+
6699+
auto combined_low_val =
6700+
combine_low(stage0_first_vreg, stage0_second_vreg);
6701+
auto combined_high_val =
6702+
combine_high(stage0_first_vreg, stage0_second_vreg);
6703+
6704+
stage0_output_vregs[stage0_first_idx] =
6705+
shuffle(combined_low_val, combined_high_val,
6706+
permute_patterns_stage0.low);
6707+
stage0_output_vregs[stage0_second_idx] =
6708+
shuffle(combined_low_val, combined_high_val,
6709+
permute_patterns_stage0.high);
6710+
}
6711+
}
65986712

6599-
// STAGE 1!
6600-
std::array<Value, 8>
6601-
stage1_output_vregs; // Stores s1_vregs from comments
6602-
const int num_pairs_stage1 =
6603-
sublane_tiling / (2 * packing); // Processes pairs of vregs
6604-
// (A,C), (B,D), (E,G), (F,H)
6605-
const int stage1_stride = num_pairs_stage1 > 1 ? 2 : 1;
6606-
6607-
for (int i = 0; i < num_pairs_stage1; ++i) {
6608-
int stage1_first_idx = (i / 2) * 4 + (i % 2);
6609-
int stage1_second_idx = stage1_first_idx + stage1_stride;
6610-
6611-
Value first_vreg = src_vregs(
6612-
{outer_idx,
6613-
stage1_first_idx * packing + major_subidx +
6614-
sublane_tiling * major_dim_slice_idx,
6615-
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
6616-
Value second_vreg = src_vregs(
6617-
{outer_idx,
6618-
stage1_second_idx * packing + major_subidx +
6619-
sublane_tiling * major_dim_slice_idx,
6620-
second_minor_dim_slice_idx, minor_most_dim_slice_idx});
6621-
6622-
auto combined_low_val = combine_low(first_vreg, second_vreg);
6623-
auto combined_high_val = combine_high(first_vreg, second_vreg);
6624-
6625-
// Initialize for (2*packing, 128) tiling and combine for larger
6626-
// tilings.
6627-
stage1_output_vregs[stage1_first_idx] =
6628-
(sublane_tiling / packing == 2)
6629-
? first_vreg
6630-
: shuffle(combined_low_val, combined_high_val,
6631-
permute_pattern_stage1_low_arr);
6632-
stage1_output_vregs[stage1_second_idx] =
6633-
(sublane_tiling / packing == 2)
6634-
? second_vreg
6635-
: shuffle(combined_low_val, combined_high_val,
6636-
permute_pattern_stage1_high_arr);
6713+
// STAGE 1! Only needed when sublane_tiling is at least 4*packing.
6714+
if (do_stage1) {
6715+
for (int i = 0; i < num_pairs_each_stage; ++i) {
6716+
int stage1_first_idx =
6717+
(i / stage1_stride) * 2 * stage1_stride +
6718+
(i % stage1_stride);
6719+
int stage1_second_idx = stage1_first_idx + stage1_stride;
6720+
6721+
Value stage1_first_vreg =
6722+
do_stage0
6723+
? stage0_output_vregs[stage1_first_idx]
6724+
: src_vregs({outer_idx,
6725+
stage1_first_idx * packing +
6726+
major_subidx +
6727+
sublane_tiling * major_dim_slice_idx,
6728+
second_minor_dim_slice_idx,
6729+
minor_most_dim_slice_idx});
6730+
Value stage1_second_vreg =
6731+
do_stage0
6732+
? stage0_output_vregs[stage1_second_idx]
6733+
: src_vregs({outer_idx,
6734+
stage1_second_idx * packing +
6735+
major_subidx +
6736+
sublane_tiling * major_dim_slice_idx,
6737+
second_minor_dim_slice_idx,
6738+
minor_most_dim_slice_idx});
6739+
6740+
auto combined_low_val =
6741+
combine_low(stage1_first_vreg, stage1_second_vreg);
6742+
auto combined_high_val =
6743+
combine_high(stage1_first_vreg, stage1_second_vreg);
6744+
6745+
stage1_output_vregs[stage1_first_idx] =
6746+
shuffle(combined_low_val, combined_high_val,
6747+
permute_patterns_stage1.low);
6748+
stage1_output_vregs[stage1_second_idx] =
6749+
shuffle(combined_low_val, combined_high_val,
6750+
permute_patterns_stage1.high);
6751+
}
66376752
}
66386753

66396754
// STAGE 2!
6640-
std::array<Value, 8>
6641-
stage2_output_vregs; // Stores s2_vregs from comments
6642-
const int num_pairs_stage2 =
6643-
sublane_tiling / (2 * packing); // Processes pairs of vregs
6644-
// from stage1_output_vregs
6645-
constexpr int stage2_stride = 1;
6646-
6647-
for (int i = 0; i < num_pairs_stage2; ++i) {
6755+
for (int i = 0; i < num_pairs_each_stage; ++i) {
66486756
int stage2_first_idx = 2 * i;
66496757
int stage2_second_idx = stage2_first_idx + stage2_stride;
66506758

6651-
Value stage2_first_vreg = stage1_output_vregs[stage2_first_idx];
6759+
Value stage2_first_vreg =
6760+
do_stage1
6761+
? stage1_output_vregs[stage2_first_idx]
6762+
: src_vregs({outer_idx,
6763+
stage2_first_idx * packing + major_subidx +
6764+
sublane_tiling * major_dim_slice_idx,
6765+
second_minor_dim_slice_idx,
6766+
minor_most_dim_slice_idx});
66526767
Value stage2_second_vreg =
6653-
stage1_output_vregs[stage2_second_idx];
6768+
do_stage1
6769+
? stage1_output_vregs[stage2_second_idx]
6770+
: src_vregs({outer_idx,
6771+
stage2_second_idx * packing +
6772+
major_subidx +
6773+
sublane_tiling * major_dim_slice_idx,
6774+
second_minor_dim_slice_idx,
6775+
minor_most_dim_slice_idx});
66546776

66556777
auto combined_low_val =
66566778
combine_low(stage2_first_vreg, stage2_second_vreg);
@@ -6659,10 +6781,10 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
66596781

66606782
stage2_output_vregs[stage2_first_idx] =
66616783
shuffle(combined_low_val, combined_high_val,
6662-
permute_pattern_stage2_low_arr);
6784+
permute_patterns_stage2.low);
66636785
stage2_output_vregs[stage2_second_idx] =
66646786
shuffle(combined_low_val, combined_high_val,
6665-
permute_pattern_stage2_high_arr);
6787+
permute_patterns_stage2.high);
66666788
}
66676789

66686790
// STAGE 3! Combine results from stage 2.
@@ -6672,9 +6794,7 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
66726794
// s2_vregs[0]..s2_vregs[1] pairing with s2_vregs[2]..s2_vregs[3].
66736795
// For (2, 128) tiling, this corresponds to s2_vregs[0] pairing
66746796
// with s2_vregs[1].
6675-
const int num_final_combines = sublane_tiling / (2 * packing);
6676-
const int final_combine_stride = sublane_tiling / (2 * packing);
6677-
for (int i = 0; i < num_final_combines; ++i) {
6797+
for (int i = 0; i < num_pairs_each_stage; ++i) {
66786798
Value lhs = stage2_output_vregs[i];
66796799
Value rhs = stage2_output_vregs[i + final_combine_stride];
66806800
auto final_combined_low = combine_low(lhs, rhs);

0 commit comments

Comments
 (0)