diff --git a/jaxlib/mosaic/dialect/tpu/layout.cc b/jaxlib/mosaic/dialect/tpu/layout.cc index 28f975976748..c58fb97d7b62 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.cc +++ b/jaxlib/mosaic/dialect/tpu/layout.cc @@ -480,12 +480,49 @@ bool VectorLayout::generalizes( // each other), and we've checked that tilings are different above. const std::array ishape_tiled_dims = getImplicitTiledDims(shape, 1); - if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] && - offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] && - offsets_[0].value_or(0) + ishape_tiled_dims[0] <= - std::min(tiling_[0], other.tiling_[0]))) { + CHECK(ishape_tiled_dims == other.getImplicitTiledDims(shape, 1)); + if (tiling_[1] != other.tiling_[1] || tiling_[1] != target_shape[1]) { return false; } + // The conditions for replication are more strict when the vreg slice is + // bigger along the replicated dimension. + // Given a target shape of (8, 128), consider the case of a 32-bit layout. + // The conditions imposed are... + // For (2, 128) tiling with replicated minor (more strict): + // - sublanes 0, 2, 4, 6 are equal + // - sublanes 1, 3, 5, 7 are equal + // For (4, 128) tiling with replicated second minor (more strict): + // - sublanes 0, 1, 2, 3 are equal + // - sublanes 4, 5, 6, 7 are equal + // For (4, 128) tiling with replicated minor (less strict): + // - sublanes 0, 4 are equal + // - sublanes 1, 5 are equal + // - sublanes 2, 6 are equal + // - sublanes 3, 7 are equal + // For (2, 128) tiling with replicated second minor (less strict): + // - sublanes 0, 1 are equal + // - sublanes 2, 3 are equal + // - sublanes 4, 5 are equal + // - sublanes 6, 7 are equal + // Of course, replicated minor also requires elements within the same + // sublane to be equal across lanes. + // Note how, for example, the condition for (4, 128) replicated 2nd minor + // implies the condition for (2, 128) replicated 2nd minor. + for (const int i : {0, 1}) { + CHECK(!offsets_[i] || offsets_[i] == other.offsets_[i]); + if (!offsets_[i] && + vregSlice(target_shape)[i] % other.vregSlice(target_shape)[i] == 0) { + // Okay: Data is replicated along this dimension and this layout's + // replication condition implies the replication condition for other's + // tiling (along this dimension). + } else if (other.offsets_[i] && + *other.offsets_[i] + ishape_tiled_dims[i] <= + std::min(tiling_[i], other.tiling_[i])) { + // Okay: Data fits within the first tile of both layouts. + } else { + return false; + } + } } return true; }