
Commit 9cf57af

tlongeri authored and Google-ML-Automation committed
[Mosaic:TPU] Fix bug in generalizes when dealing with replicated offsets and different tilings

It can lead to incorrect behavior for equivalentTo, too.

PiperOrigin-RevId: 798007578
1 parent b9291e5 commit 9cf57af


1 file changed, +41 -4 lines changed


jaxlib/mosaic/dialect/tpu/layout.cc

Lines changed: 41 additions & 4 deletions
@@ -516,12 +516,49 @@ bool VectorLayout::generalizes(
     // each other), and we've checked that tilings are different above.
     const std::array<int64_t, 2> ishape_tiled_dims =
         getImplicitTiledDims(shape, 1);
-    if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] &&
-          offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] &&
-          offsets_[0].value_or(0) + ishape_tiled_dims[0] <=
-              std::min(tiling_[0], other.tiling_[0]))) {
+    CHECK(ishape_tiled_dims == other.getImplicitTiledDims(shape, 1));
+    if (tiling_[1] != other.tiling_[1] || tiling_[1] != target_shape[1]) {
       return false;
     }
+    // The conditions for replication are more strict when the vreg slice is
+    // bigger along the replicated dimension.
+    // Given a target shape of (8, 128), consider the case of a 32-bit layout.
+    // The conditions imposed are...
+    // For (2, 128) tiling with replicated minor (more strict):
+    // - sublanes 0, 2, 4, 6 are equal
+    // - sublanes 1, 3, 5, 7 are equal
+    // For (4, 128) tiling with replicated second minor (more strict):
+    // - sublanes 0, 1, 2, 3 are equal
+    // - sublanes 4, 5, 6, 7 are equal
+    // For (4, 128) tiling with replicated minor (less strict):
+    // - sublanes 0, 4 are equal
+    // - sublanes 1, 5 are equal
+    // - sublanes 2, 6 are equal
+    // - sublanes 3, 7 are equal
+    // For (2, 128) tiling with replicated second minor (less strict):
+    // - sublanes 0, 1 are equal
+    // - sublanes 2, 3 are equal
+    // - sublanes 4, 5 are equal
+    // - sublanes 6, 7 are equal
+    // Of course, replicated minor also requires elements within the same
+    // sublane to be equal across lanes.
+    // Note how, for example, the condition for (4, 128) replicated 2nd minor
+    // implies the condition for (2, 128) replicated 2nd minor.
+    for (const int i : {0, 1}) {
+      CHECK(!offsets_[i] || offsets_[i] == other.offsets_[i]);
+      if (!offsets_[i] &&
+          vregSlice(target_shape)[i] % other.vregSlice(target_shape)[i] == 0) {
+        // Okay: Data is replicated along this dimension and this layout's
+        // replication condition implies the replication condition for other's
+        // tiling (along this dimension).
+      } else if (other.offsets_[i] &&
+                 *other.offsets_[i] + ishape_tiled_dims[i] <=
+                     std::min(tiling_[i], other.tiling_[i])) {
+        // Okay: Data fits within the first tile of both layouts.
+      } else {
+        return false;
+      }
+    }
   }
   return true;
 }
