diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 9d9fcf624a40..e5cffcc407d1 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -59,7 +59,8 @@ std::pair mightCommunicateBetweenChips(Operation *op); std::unique_ptr> createInferMemRefLayoutPass( int hardware_generation, std::array target_shape, - const TpuTilingFlags& tpu_tiling_flags, bool align = true); + const TpuTilingFlags& tpu_tiling_flags, bool align = true, + bool infer_kernel_arguments = true); #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" diff --git a/tests/pallas/tpu_sparsecore_pallas_test.py b/tests/pallas/tpu_sparsecore_pallas_test.py index 5f5c1726ffba..6e3e1a3b215d 100644 --- a/tests/pallas/tpu_sparsecore_pallas_test.py +++ b/tests/pallas/tpu_sparsecore_pallas_test.py @@ -296,7 +296,13 @@ def kernel(x_ref, o_ref): @hp.given(hps.data()) def test_block_spec_untiled_slicing(self, data): if not self.USE_TC_TILING: - self.skipTest("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')") + self.skipTest( + "Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')" + ) + else: + self.skipTest( + "Test uncovers a bug: @reproduce_failure('6.80.0', b'AAEAAQAAAAA=')" + ) slice_shape = data.draw( hps.lists( hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4 @@ -1065,7 +1071,7 @@ def scoped_kernel(scratch_ref): pl.run_scoped( scoped_kernel, plsc.MemoryRef( - x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(2, 1)] + x.shape, x_ref.dtype, memory_space=pltpu.VMEM, tiling=[(1, 8)] ), )