5353)
5454
5555
56- @unittest .skipIf (jtu .is_cloud_tpu (), "TODO: b/465504705" )
5756class PythonCallbackTest (jtu .JaxTestCase ):
5857
5958 def setUp (self ):
@@ -670,7 +669,6 @@ def f(x):
670669 np .testing .assert_array_equal (x , result )
671670
672671
673- @unittest .skipIf (jtu .is_cloud_tpu (), "TODO: b/465504705" )
674672class PureCallbackTest (jtu .JaxTestCase ):
675673
676674 def setUp (self ):
@@ -1152,9 +1150,6 @@ def tearDown(self):
11521150 dispatch .runtime_tokens .clear ()
11531151
11541152 def test_io_callback_can_mutate_state (self ):
1155- if jtu .is_cloud_tpu ():
1156- self .skipTest ("TODO: b/465504705" )
1157-
11581153 x = 0
11591154 def cb ():
11601155 nonlocal x
@@ -1171,9 +1166,6 @@ def f():
11711166 self .assertEqual (x , 2 )
11721167
11731168 def test_io_callback_can_be_batched_if_unordered (self ):
1174- if jtu .is_cloud_tpu ():
1175- self .skipTest ("TODO: b/465504705" )
1176-
11771169 _mut = 0
11781170 def cb (x ):
11791171 nonlocal _mut
@@ -1282,9 +1274,6 @@ def f(x, y):
12821274 def test_can_use_io_callback_in_pjit (
12831275 self , * , ordered : bool , with_sharding : bool
12841276 ):
1285- if jtu .is_cloud_tpu ():
1286- self .skipTest ("TODO: b/465504705" )
1287-
12881277 devices = jax .devices ()
12891278 mesh = jax .sharding .Mesh (np .array (devices ), ['dev' ])
12901279
@@ -1345,9 +1334,6 @@ def f(x):
13451334 @jtu .ignore_warning (message = '.*Please use `jax.jit` instead.*' ,
13461335 category = DeprecationWarning )
13471336 def test_sequence_pjit_io_callback_ordered (self ):
1348- if jtu .is_cloud_tpu ():
1349- self .skipTest ("TODO: b/465504705" )
1350-
13511337 if jtu .is_device_tpu (7 , 'x' ):
13521338 self .skipTest ('TODO(b/453664256): Failing on TPU 7x.' )
13531339
@@ -1409,8 +1395,6 @@ def f_base(i, x):
14091395 single_device = True )
14101396 )
14111397 def test_can_shard_io_callback_manually (self , single_device : bool ):
1412- if jtu .is_cloud_tpu ():
1413- self .skipTest ("TODO: b/465504705" )
14141398
14151399 devices = jax .devices ()
14161400 if single_device :
@@ -1445,9 +1429,6 @@ def f(shard_ids, x):
14451429
14461430 def test_batching_with_side_effects (self ):
14471431 # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
1448- if jtu .is_cloud_tpu ():
1449- self .skipTest ("TODO: b/465504705" )
1450-
14511432 x_lst = []
14521433 def append_x (x ):
14531434 nonlocal x_lst
@@ -1464,9 +1445,6 @@ def f(x):
14641445
14651446 def test_batching_with_side_effects_while_loop (self ):
14661447 # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
1467- if jtu .is_cloud_tpu ():
1468- self .skipTest ("TODO: b/465504705" )
1469-
14701448 x_lst = []
14711449 def append_x (x ):
14721450 nonlocal x_lst
0 commit comments