@@ -297,6 +297,10 @@ def kernel(x_ref, o_ref):
297297 def test_block_spec_untiled_slicing (self , data ):
298298 if not self .USE_TC_TILING :
299299 self .skipTest ("Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEBAQAAAAA=')" )
300+ else :
301+ self .skipTest (
302+ "Test uncoveres a bug: @reproduce_failure('6.80.0', b'AAEAAQAAAAA=')"
303+ )
300304 slice_shape = data .draw (
301305 hps .lists (
302306 hps .integers (1 , 3 ), min_size = (1 + self .USE_TC_TILING ), max_size = 4
@@ -438,6 +442,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
438442 np .testing .assert_array_equal (kernel (x , indices )[0 ], x [indices ])
439443
440444 def test_large_gather_1d (self ):
445+ self .skip_if_tc_tiling ()
446+
441447 x = jnp .arange (1024 )
442448 indices = jax .random .permutation (jax .random .key (42 ), x )
443449
@@ -471,6 +477,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
471477 np .testing .assert_array_equal (kernel (x , indices ), x [1 , 2 , indices ])
472478
473479 def test_gather_2d_with_indexing (self ):
480+ self .skip_if_tc_tiling ()
481+
474482 x = jnp .arange (4 * 16 * 128 ).reshape (4 , 16 , 128 )
475483 indices = jax .random .permutation (jax .random .key (42 ), jnp .arange (8 ))
476484
@@ -997,6 +1005,7 @@ def kernel(x_ref, o_ref):
9971005 np .testing .assert_array_equal (kernel (x )[5 :13 :2 ], x [2 :6 ])
9981006
9991007 def test_scalar_load_store (self ):
1008+ self .skip_if_tc_tiling ()
10001009
10011010 @self .vector_subcore_kernel (
10021011 in_specs = (pl .BlockSpec (memory_space = pltpu .HBM ),),
@@ -1054,6 +1063,8 @@ def kernel(x_ref, o_hbm_ref):
10541063 np .testing .assert_array_equal (kernel (x ), x )
10551064
10561065 def test_run_scoped_with_tiling (self ):
1066+ self .skip_if_tc_tiling ()
1067+
10571068 x = jnp .arange (2 * 8 ).reshape (- 1 , 8 )
10581069
10591070 @self .vector_subcore_kernel (out_shape = x )
@@ -1089,6 +1100,8 @@ def kernel(x_ref, o_ref):
10891100 np .testing .assert_array_equal (kernel (x ), x )
10901101
10911102 def test_scratch (self ):
1103+ self .skip_if_tc_tiling ()
1104+
10921105 x = jnp .arange (8 )
10931106
10941107 @self .vector_subcore_kernel (
0 commit comments