From 89305e7f39a1b27e623040859f1550ae4f446937 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 17 Nov 2025 03:25:21 -0800 Subject: [PATCH] [pallas:sc] Skipped a few more tests under TC tiling due to verification errors PiperOrigin-RevId: 833257620 --- jaxlib/mosaic/dialect/tpu/tpu_dialect.h | 3 ++- tests/pallas/tpu_sparsecore_pallas_test.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 9d9fcf624a40..e5cffcc407d1 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -59,7 +59,8 @@ std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation, std::array target_shape, - const TpuTilingFlags& tpu_tiling_flags, bool align = true); + const TpuTilingFlags& tpu_tiling_flags, bool align = true, + bool infer_kernel_arguments = true); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 5f5c1726ffba..6e3e1a3b215d 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -296,7 +296,13 @@ def kernel(x_ref, o_ref): @hp.given(hps.data()) def test_block_spec_untiled_slicing(self, data): if not self.USE_TC_TILING: - self.skipTest("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')") + self.skipTest( + "Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')" + ) + else: + self.skipTest( + "Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEAAQAAAAA=')" + ) slice_shape = data.draw( hps.lists( hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4 @@ -1065,7 +1071,7 @@ def scoped_kernel(scratch_ref): pl.run_scoped( scoped_kernel, plsc.MemoryRef( - x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(2, 1)] + x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(1, 8)] ), )