Skip to content

Commit c3cbd25

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Fix bug in generalizes when dealing with replicated offsets and different tilings
It can lead to incorrect behavior for equivalentTo, too. PiperOrigin-RevId: 798007578
1 parent 4257c62 commit c3cbd25

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

jaxlib/mosaic/dialect/tpu/layout.cc

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,12 +480,49 @@ bool VectorLayout::generalizes(
480480
// each other), and we've checked that tilings are different above.
481481
const std::array<int64_t, 2> ishape_tiled_dims =
482482
getImplicitTiledDims(shape, 1);
483-
if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] &&
484-
offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] &&
485-
offsets_[0].value_or(0) + ishape_tiled_dims[0] <=
486-
std::min(tiling_[0], other.tiling_[0]))) {
483+
CHECK(ishape_tiled_dims == other.getImplicitTiledDims(shape, 1));
484+
if (tiling_[1] != other.tiling_[1] || tiling_[1] != target_shape[1]) {
487485
return false;
488486
}
487+
// The conditions for replication are more strict when the vreg slice is
488+
// bigger along the replicated dimension.
489+
// Given a target shape of (8, 128), consider the case of a 32-bit layout.
490+
// The conditions imposed are...
491+
// For (2, 128) tiling with replicated minor (more strict):
492+
// - sublanes 0, 2, 4, 6 are equal
493+
// - sublanes 1, 3, 5, 7 are equal
494+
// For (4, 128) tiling with replicated second minor (more strict):
495+
// - sublanes 0, 1, 2, 3 are equal
496+
// - sublanes 4, 5, 6, 7 are equal
497+
// For (4, 128) tiling with replicated minor (less strict):
498+
// - sublanes 0, 4 are equal
499+
// - sublanes 1, 5 are equal
500+
// - sublanes 2, 6 are equal
501+
// - sublanes 3, 7 are equal
502+
// For (2, 128) tiling with replicated second minor (less strict):
503+
// - sublanes 0, 1 are equal
504+
// - sublanes 2, 3 are equal
505+
// - sublanes 4, 5 are equal
506+
// - sublanes 6, 7 are equal
507+
// Of course, replicated minor also requires elements within the same
508+
// sublane to be equal across lanes.
509+
// Note how, for example, the condition for (4, 128) replicated 2nd minor
510+
// implies the condition for (2, 128) replicated 2nd minor.
511+
for (const int i : {0, 1}) {
512+
CHECK(!offsets_[i] || offsets_[i] == other.offsets_[i]);
513+
if (!offsets_[i] &&
514+
vregSlice(target_shape)[i] % other.vregSlice(target_shape)[i] == 0) {
515+
// Okay: Data is replicated along this dimension and this layout's
516+
// replication condition implies the replication condition for other's
517+
// tiling (along this dimension).
518+
} else if (other.offsets_[i] &&
519+
*other.offsets_[i] + ishape_tiled_dims[i] <=
520+
std::min(tiling_[i], other.tiling_[i])) {
521+
// Okay: Data fits within the first tile of both layouts.
522+
} else {
523+
return false;
524+
}
525+
}
489526
}
490527
return true;
491528
}

0 commit comments

Comments
 (0)