45 changes: 41 additions & 4 deletions jaxlib/mosaic/dialect/tpu/layout.cc
@@ -480,12 +480,49 @@ bool VectorLayout::generalizes(
     // each other), and we've checked that tilings are different above.
     const std::array<int64_t, 2> ishape_tiled_dims =
         getImplicitTiledDims(shape, 1);
-    if (!(tiling_[1] == other.tiling_[1] && tiling_[1] == target_shape[1] &&
-          offsets_[1].value_or(0) + ishape_tiled_dims[1] <= target_shape[1] &&
-          offsets_[0].value_or(0) + ishape_tiled_dims[0] <=
-              std::min(tiling_[0], other.tiling_[0]))) {
+    CHECK(ishape_tiled_dims == other.getImplicitTiledDims(shape, 1));
+    if (tiling_[1] != other.tiling_[1] || tiling_[1] != target_shape[1]) {
       return false;
     }
+    // The conditions for replication are more strict when the vreg slice is
+    // bigger along the replicated dimension.
+    // Given a target shape of (8, 128), consider the case of a 32-bit layout.
+    // The conditions imposed are...
+    // For (2, 128) tiling with replicated minor (more strict):
+    //   - sublanes 0, 2, 4, 6 are equal
+    //   - sublanes 1, 3, 5, 7 are equal
+    // For (4, 128) tiling with replicated second minor (more strict):
+    //   - sublanes 0, 1, 2, 3 are equal
+    //   - sublanes 4, 5, 6, 7 are equal
+    // For (4, 128) tiling with replicated minor (less strict):
+    //   - sublanes 0, 4 are equal
+    //   - sublanes 1, 5 are equal
+    //   - sublanes 2, 6 are equal
+    //   - sublanes 3, 7 are equal
+    // For (2, 128) tiling with replicated second minor (less strict):
+    //   - sublanes 0, 1 are equal
+    //   - sublanes 2, 3 are equal
+    //   - sublanes 4, 5 are equal
+    //   - sublanes 6, 7 are equal
+    // Of course, replicated minor also requires elements within the same
+    // sublane to be equal across lanes.
+    // Note how, for example, the condition for (4, 128) replicated 2nd minor
+    // implies the condition for (2, 128) replicated 2nd minor.
+    for (const int i : {0, 1}) {
+      CHECK(!offsets_[i] || offsets_[i] == other.offsets_[i]);
+      if (!offsets_[i] &&
+          vregSlice(target_shape)[i] % other.vregSlice(target_shape)[i] == 0) {
+        // Okay: Data is replicated along this dimension and this layout's
+        // replication condition implies the replication condition for other's
+        // tiling (along this dimension).
+      } else if (other.offsets_[i] &&
+                 *other.offsets_[i] + ishape_tiled_dims[i] <=
+                     std::min(tiling_[i], other.tiling_[i])) {
+        // Okay: Data fits within the first tile of both layouts.
+      } else {
+        return false;
+      }
+    }
   }
   return true;
 }
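
The sublane lists in the comment above, and the divisibility test on the vreg slice, can be checked with a small standalone sketch. It is not part of the change and does not use the Mosaic API: `VregSlice`, `GroupOfSublane`, `Implies`, and `SliceDivides` are hypothetical helpers. It assumes a 32-bit layout on an (8, 128) target shape with (rows, 128) tiling, where sublane `s` holds row `s % rows` of tile `s / rows`, which is consistent with the lists in the comment. It models only the relationships between sublanes, not the extra requirement that replicated minor data be equal across lanes.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Assumption for illustration: 32-bit data on an (8, 128) target shape.
constexpr int64_t kSublanes = 8;
constexpr int64_t kLanes = 128;

// Hypothetical stand-in for the vreg slice: the extent of the array covered
// by one vreg, as {second minor, minor}, for a (tile_rows, 128) tiling.
std::array<int64_t, 2> VregSlice(int64_t tile_rows) {
  assert(tile_rows > 0 && kSublanes % tile_rows == 0);
  const int64_t tiles_per_vreg = kSublanes / tile_rows;
  return {tile_rows, tiles_per_vreg * kLanes};
}

// For each sublane, returns the id of the group of sublanes that must hold
// equal data when `dim` is replicated (0 = second minor, 1 = minor).
std::vector<int64_t> GroupOfSublane(int64_t tile_rows, int dim) {
  assert(tile_rows > 0 && kSublanes % tile_rows == 0);
  std::vector<int64_t> group(kSublanes);
  for (int64_t s = 0; s < kSublanes; ++s) {
    // Replicated second minor: all rows of one tile are equal.
    // Replicated minor: the same row of every tile is equal.
    group[s] = (dim == 0) ? s / tile_rows : s % tile_rows;
  }
  return group;
}

// True if every pair of sublanes forced equal by `weak` is also forced equal
// by `strong`, i.e. the `strong` condition implies the `weak` one.
bool Implies(const std::vector<int64_t>& strong,
             const std::vector<int64_t>& weak) {
  for (int64_t a = 0; a < kSublanes; ++a) {
    for (int64_t b = a + 1; b < kSublanes; ++b) {
      if (weak[a] == weak[b] && strong[a] != strong[b]) return false;
    }
  }
  return true;
}

// The divisibility test from the loop in the diff, restated on the stand-in.
bool SliceDivides(int64_t this_rows, int64_t other_rows, int dim) {
  return VregSlice(this_rows)[dim] % VregSlice(other_rows)[dim] == 0;
}
```

Under these assumptions, `Implies(GroupOfSublane(4, 0), GroupOfSublane(2, 0))` holds while the reverse does not, and `SliceDivides` gives the same answer in both directions and for both dimensions of the (2, 128) and (4, 128) pair.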