Skip to content

Commit 3a83fa4

Browse files
danielsuoGoogle-ML-Automation
authored andcommitted
Re-enable TPU tests now that TPU thread stack sizes have been increased.
Reverts 4d9ff5b PiperOrigin-RevId: 840845684
1 parent bd63099 commit 3a83fa4

File tree

4 files changed

+0
-31
lines changed

4 files changed

+0
-31
lines changed

tests/multiprocess/pjit_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,6 @@ def f(x):
527527
self.assertEqual(output(), "")
528528

529529
def test_print_in_multihost_shard_map(self):
530-
if jtu.is_cloud_tpu():
531-
self.skipTest("TODO: b/465504705")
532-
533530
devices = jax.devices()
534531
mesh = jax.sharding.Mesh(devices, ("i",))
535532
num_devices = jax.local_device_count()

tests/pallas/ops_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2756,9 +2756,6 @@ class OpsInterpretTest(OpsTest):
27562756
INTERPRET = True
27572757

27582758
def test_debug_print(self):
2759-
if jtu.is_cloud_tpu():
2760-
self.skipTest("TODO: b/465504705")
2761-
27622759
@functools.partial(
27632760
self.pallas_call,
27642761
out_shape=jax.ShapeDtypeStruct((2,), jnp.float32),

tests/pjit_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4247,9 +4247,6 @@ def test_in_out_shardings_unconstrained_error(self):
42474247
in_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'x')))
42484248

42494249
def test_empty_io_callback_under_shard_map(self):
4250-
if jtu.is_cloud_tpu():
4251-
self.skipTest("TODO: b/465504705")
4252-
42534250
mesh = jtu.create_mesh((4,), 'i')
42544251

42554252
def empty_callback(x):

tests/python_callback_test.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
)
5454

5555

56-
@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705")
5756
class PythonCallbackTest(jtu.JaxTestCase):
5857

5958
def setUp(self):
@@ -670,7 +669,6 @@ def f(x):
670669
np.testing.assert_array_equal(x, result)
671670

672671

673-
@unittest.skipIf(jtu.is_cloud_tpu(), "TODO: b/465504705")
674672
class PureCallbackTest(jtu.JaxTestCase):
675673

676674
def setUp(self):
@@ -1152,9 +1150,6 @@ def tearDown(self):
11521150
dispatch.runtime_tokens.clear()
11531151

11541152
def test_io_callback_can_mutate_state(self):
1155-
if jtu.is_cloud_tpu():
1156-
self.skipTest("TODO: b/465504705")
1157-
11581153
x = 0
11591154
def cb():
11601155
nonlocal x
@@ -1171,9 +1166,6 @@ def f():
11711166
self.assertEqual(x, 2)
11721167

11731168
def test_io_callback_can_be_batched_if_unordered(self):
1174-
if jtu.is_cloud_tpu():
1175-
self.skipTest("TODO: b/465504705")
1176-
11771169
_mut = 0
11781170
def cb(x):
11791171
nonlocal _mut
@@ -1282,9 +1274,6 @@ def f(x, y):
12821274
def test_can_use_io_callback_in_pjit(
12831275
self, *, ordered: bool, with_sharding: bool
12841276
):
1285-
if jtu.is_cloud_tpu():
1286-
self.skipTest("TODO: b/465504705")
1287-
12881277
devices = jax.devices()
12891278
mesh = jax.sharding.Mesh(np.array(devices), ['dev'])
12901279

@@ -1345,9 +1334,6 @@ def f(x):
13451334
@jtu.ignore_warning(message='.*Please use `jax.jit` instead.*',
13461335
category=DeprecationWarning)
13471336
def test_sequence_pjit_io_callback_ordered(self):
1348-
if jtu.is_cloud_tpu():
1349-
self.skipTest("TODO: b/465504705")
1350-
13511337
if jtu.is_device_tpu(7, 'x'):
13521338
self.skipTest('TODO(b/453664256): Failing on TPU 7x.')
13531339

@@ -1409,8 +1395,6 @@ def f_base(i, x):
14091395
single_device=True)
14101396
)
14111397
def test_can_shard_io_callback_manually(self, single_device: bool):
1412-
if jtu.is_cloud_tpu():
1413-
self.skipTest("TODO: b/465504705")
14141398

14151399
devices = jax.devices()
14161400
if single_device:
@@ -1445,9 +1429,6 @@ def f(shard_ids, x):
14451429

14461430
def test_batching_with_side_effects(self):
14471431
# https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195
1448-
if jtu.is_cloud_tpu():
1449-
self.skipTest("TODO: b/465504705")
1450-
14511432
x_lst = []
14521433
def append_x(x):
14531434
nonlocal x_lst
@@ -1464,9 +1445,6 @@ def f(x):
14641445

14651446
def test_batching_with_side_effects_while_loop(self):
14661447
# https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219
1467-
if jtu.is_cloud_tpu():
1468-
self.skipTest("TODO: b/465504705")
1469-
14701448
x_lst = []
14711449
def append_x(x):
14721450
nonlocal x_lst

0 commit comments

Comments
 (0)