Skip to content

Commit cee04a5

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Fix bug in generalizes when dealing with replicated offsets and different tilings
It can lead to incorrect behavior for equivalentTo, too. PiperOrigin-RevId: 798007578
1 parent c911c76 commit cee04a5

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

jaxlib/mosaic/dialect/tpu/layout.cc

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,42 @@ bool VectorLayout::generalizes(
516516
// each other), and we've checked that tilings are different above.
517517
const std::array<int64_t, 2> ishape_tiled_dims =
518518
getImplicitTiledDims(shape, 1);
519-
if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] &&
520-
offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] &&
521-
offsets_[0].value_or(0) + ishape_tiled_dims[0] <=
522-
std::min(tiling_[0], other.tiling_[0]))) {
519+
CHECK(ishape_tiled_dims == other.getImplicitTiledDims(shape, 1));
520+
if (tiling_[1] != other.tiling_[1] || tiling_[1] != target_shape[1]) {
521+
return false;
522+
}
523+
// Replication imposes stronger conditions when the vreg slice is bigger
524+
// along the given dimension.
525+
// Given a target shape of (8, 128), consider the case of a 32-bit layout.
526+
// The conditions imposed are...
527+
// For (2, 128) tiling with replicated minor (stronger):
528+
// - sublanes 0, 2, 4, 6 are equal
529+
// - sublanes 1, 3, 5, 7 are equal
530+
// For (4, 128) tiling with replicated second minor (stronger):
531+
// - sublanes 0, 1, 2, 3 are equal
532+
// - sublanes 4, 5, 6, 7 are equal
533+
// For (4, 128) tiling with replicated minor (weaker):
534+
// - sublanes 0, 4 are equal
535+
// - sublanes 1, 5 are equal
536+
// - sublanes 2, 6 are equal
537+
// - sublanes 3, 7 are equal
538+
// For (2, 128) tiling with replicated second minor (weaker):
539+
// - sublanes 0, 1 are equal
540+
// - sublanes 2, 3 are equal
541+
// - sublanes 4, 5 are equal
542+
// - sublanes 6, 7 are equal
543+
// Note: Of course, replicated minor also requires elements within the same
544+
// sublane to be equal across lanes.
545+
CHECK(tiling_[0] % other.tiling_[0] == 0 ||
546+
other.tiling_[0] % tiling_[0] == 0);
547+
if (!(!offsets_[0] && tiling_[0] > other.tiling_[0]) &&
548+
!(other.offsets_[0] && *other.offsets_[0] + ishape_tiled_dims[0] <=
549+
std::min(tiling_[0], other.tiling_[0]))) {
550+
return false;
551+
}
552+
if (!(!offsets_[1] && tiling_[0] < other.tiling_[0]) &&
553+
!(other.offsets_[1] &&
554+
*other.offsets_[1] + ishape_tiled_dims[1] <= tiling_[1])) {
523555
return false;
524556
}
525557
}

0 commit comments

Comments
 (0)