Skip to content

Commit dc7d56e

Browse files
yueshengys authored and Google-ML-Automation committed
[Mosaic TPU] Allow using a replicated lane offset for the intermediate safe offsets in mask relayout, now that relayouts such as {0,*},(8,128) -> {0,*},(4,128) are supported.
PiperOrigin-RevId: 826555654
1 parent 7a007ea commit dc7d56e

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8434,7 +8434,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
84348434
retiled.Each([&](absl::Span<const int64_t> dst_idx, Value* const vreg) {
84358435
// Recall that this is a one-to-many vregs relayout. Each destination vreg
84368436
// will hold a different part of the source data.
8437-
// For example, a (8, 128) -> (4, 128) retiling with replicated 2nd minor:
8437+
// For example, a (8, 128) -> (4, 128) retiling with replicated minor:
84388438
// - Destination vreg 0 comes from the first 4 sublanes of source vreg 0,
84398439
// with gather pattern [0, 1, 2, 3, 0, 1, 2, 3]
84408440
// - Destination vreg 1 comes from the last 4 sublanes of source vreg 0,

jaxlib/mosaic/dialect/tpu/transforms/relayout_insertion.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "mlir/IR/Visitors.h"
3232
#include "mlir/Pass/Pass.h"
3333
#include "mlir/Support/LLVM.h"
34+
#include "mlir/Support/WalkResult.h"
3435
#include "jaxlib/mosaic/dialect/tpu/layout.h"
3536
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
3637
#include "jaxlib/mosaic/dialect/tpu/util.h"
@@ -138,9 +139,8 @@ FailureOr<TypedValue<VectorType>> relayout(
138139
auto safe_offsets = LayoutOffsets{
139140
src.offsets()[0].has_value() ? *src.offsets()[0] % safe_vreg_slice[0]
140141
: LayoutOffset(),
141-
src.offsets()[1].has_value()
142-
? *src.offsets()[1] % safe_vreg_slice[1]
143-
: 0, // TODO(b/452689987): change to LayoutOffset() after resolved.
142+
src.offsets()[1].has_value() ? *src.offsets()[1] % safe_vreg_slice[1]
143+
: LayoutOffset(),
144144
};
145145
auto safe_src = VectorLayout(src.bitwidth(), safe_offsets, safe_tiling,
146146
dst.implicit_dim());
@@ -281,4 +281,4 @@ std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
281281
target_shape);
282282
}
283283

284-
} // namespace mlir::tpu
284+
} // namespace mlir::tpu

tests/pallas/tpu_ops_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,8 @@ def else_0():
668668
self.assertEqual(output, 0)
669669

670670
def test_retiling_with_replicated_lane(self):
671-
self.skipTest("TODO(b/452689987)")
671+
if not jtu.if_cloud_tpu_at_least(2025, 11, 5):
672+
self.skipTest("Test requires libtpu from 2025/11/5 or later")
672673
shape = (32, 1)
673674
broadcast_shape = (32, 256)
674675

0 commit comments

Comments
 (0)