Skip to content

Commit 0d95c2b

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:sc] Skipped a few more tests under TC tiling due to verification errors
PiperOrigin-RevId: 833257620
1 parent dd62dd7 commit 0d95c2b

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation *op);
5959

6060
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
6161
int hardware_generation, std::array<int64_t, 2> target_shape,
62-
const TpuTilingFlags& tpu_tiling_flags, bool align = true);
62+
const TpuTilingFlags& tpu_tiling_flags, bool align = true,
63+
bool infer_kernel_arguments = true);
6364

6465
#define GEN_PASS_DECL_MOSAICSERDEPASS
6566
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,10 @@ def kernel(x_ref, o_ref):
297297
def test_block_spec_untiled_slicing(self, data):
298298
if not self.USE_TC_TILING:
299299
self.skipTest("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')")
300+
else:
301+
self.skipTest(
302+
"Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEAAQAAAAA=')"
303+
)
300304
slice_shape = data.draw(
301305
hps.lists(
302306
hps.integers(1, 3), min_size=(1 + self.USE_TC_TILING), max_size=4
@@ -438,6 +442,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
438442
np.testing.assert_array_equal(kernel(x, indices)[0], x[indices])
439443

440444
def test_large_gather_1d(self):
445+
self.skip_if_tc_tiling()
446+
441447
x = jnp.arange(1024)
442448
indices = jax.random.permutation(jax.random.key(42), x)
443449

@@ -471,6 +477,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
471477
np.testing.assert_array_equal(kernel(x, indices), x[1, 2, indices])
472478

473479
def test_gather_2d_with_indexing(self):
480+
self.skip_if_tc_tiling()
481+
474482
x = jnp.arange(4 * 16 * 128).reshape(4, 16, 128)
475483
indices = jax.random.permutation(jax.random.key(42), jnp.arange(8))
476484

@@ -997,6 +1005,7 @@ def kernel(x_ref, o_ref):
9971005
np.testing.assert_array_equal(kernel(x)[5:13:2], x[2:6])
9981006

9991007
def test_scalar_load_store(self):
1008+
self.skip_if_tc_tiling()
10001009

10011010
@self.vector_subcore_kernel(
10021011
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
@@ -1054,6 +1063,8 @@ def kernel(x_ref, o_hbm_ref):
10541063
np.testing.assert_array_equal(kernel(x), x)
10551064

10561065
def test_run_scoped_with_tiling(self):
1066+
self.skip_if_tc_tiling()
1067+
10571068
x = jnp.arange(2 * 8).reshape(-1, 8)
10581069

10591070
@self.vector_subcore_kernel(out_shape=x)
@@ -1089,6 +1100,8 @@ def kernel(x_ref, o_ref):
10891100
np.testing.assert_array_equal(kernel(x), x)
10901101

10911102
def test_scratch(self):
1103+
self.skip_if_tc_tiling()
1104+
10921105
x = jnp.arange(8)
10931106

10941107
@self.vector_subcore_kernel(

0 commit comments

Comments
 (0)