Skip to content

Commit 96fb4bc

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:sc] Skipped a few more tests under TC tiling due to verification errors
PiperOrigin-RevId: 833257620
1 parent ecd0b33 commit 96fb4bc

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
401401
np.testing.assert_array_equal(kernel(x, indices)[0], x[indices])
402402

403403
def test_large_gather_1d(self):
404+
self.skip_if_tc_tiling()
405+
404406
x = jnp.arange(1024)
405407
indices = jax.random.permutation(jax.random.key(42), x)
406408

@@ -944,6 +946,7 @@ def kernel(x_ref, o_ref):
944946
np.testing.assert_array_equal(kernel(x)[5:13:2], x[2:6])
945947

946948
def test_scalar_load_store(self):
949+
self.skip_if_tc_tiling()
947950

948951
@self.vector_subcore_kernel(
949952
in_specs=(pl.BlockSpec(memory_space=pltpu.HBM),),
@@ -1001,6 +1004,8 @@ def kernel(x_ref, o_hbm_ref):
10011004
np.testing.assert_array_equal(kernel(x), x)
10021005

10031006
def test_run_scoped_with_tiling(self):
1007+
self.skip_if_tc_tiling()
1008+
10041009
x = jnp.arange(2 * 8).reshape(-1, 8)
10051010

10061011
@self.vector_subcore_kernel(out_shape=x)
@@ -1036,6 +1041,8 @@ def kernel(x_ref, o_ref):
10361041
np.testing.assert_array_equal(kernel(x), x)
10371042

10381043
def test_scratch(self):
1044+
self.skip_if_tc_tiling()
1045+
10391046
x = jnp.arange(8)
10401047

10411048
@self.vector_subcore_kernel(

0 commit comments

Comments
 (0)