@@ -401,6 +401,8 @@ def kernel(x_hbm_ref, indices_ref, o_ref):
401401 np .testing .assert_array_equal (kernel (x , indices )[0 ], x [indices ])
402402
403403 def test_large_gather_1d (self ):
404+ self .skip_if_tc_tiling ()
405+
404406 x = jnp .arange (1024 )
405407 indices = jax .random .permutation (jax .random .key (42 ), x )
406408
@@ -944,6 +946,7 @@ def kernel(x_ref, o_ref):
944946 np .testing .assert_array_equal (kernel (x )[5 :13 :2 ], x [2 :6 ])
945947
946948 def test_scalar_load_store (self ):
949+ self .skip_if_tc_tiling ()
947950
948951 @self .vector_subcore_kernel (
949952 in_specs = (pl .BlockSpec (memory_space = pltpu .HBM ),),
@@ -1001,6 +1004,8 @@ def kernel(x_ref, o_hbm_ref):
10011004 np .testing .assert_array_equal (kernel (x ), x )
10021005
10031006 def test_run_scoped_with_tiling (self ):
1007+ self .skip_if_tc_tiling ()
1008+
10041009 x = jnp .arange (2 * 8 ).reshape (- 1 , 8 )
10051010
10061011 @self .vector_subcore_kernel (out_shape = x )
@@ -1036,6 +1041,8 @@ def kernel(x_ref, o_ref):
10361041 np .testing .assert_array_equal (kernel (x ), x )
10371042
10381043 def test_scratch (self ):
1044+ self .skip_if_tc_tiling ()
1045+
10391046 x = jnp .arange (8 )
10401047
10411048 @self .vector_subcore_kernel (
0 commit comments