Skip to content

Commit 0bb901b

Browse files
committed
[export] Fix the "with mesh" deprecation warning
1 parent 03256b3 commit 0bb901b

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tests/export_back_compat_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,6 @@ def func(x):
784784
data = self.load_testdata(cuda_threefry2x32.data_2024_07_30)
785785
self.run_one_test(func, data)
786786

787-
@jtu.ignore_warning(category=DeprecationWarning,
788-
message='`with mesh:` context manager')
789787
def test_tpu_sharding(self):
790788
# Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
791789
if not jtu.test_device_matches(["tpu"]) or len(jax.devices()) < 2:
@@ -815,7 +813,7 @@ def func(x): # b: f32[2, 4]
815813
# the expected custom call targets for old test data that was serialized
816814
# with custom calls.
817815
for data, custom_call_targets_override in data:
818-
with mesh:
816+
with jax.set_mesh(mesh):
819817
if jax.config.jax_use_shardy_partitioner:
820818
self.run_one_test(
821819
func, self.load_testdata(data["shardy"]),
@@ -1010,8 +1008,6 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol):
10101008

10111009
class ShardyCompatTest(bctu.CompatTestBase):
10121010

1013-
@jtu.ignore_warning(category=DeprecationWarning,
1014-
message='`with mesh:` context manager')
10151011
def test_shardy_sharding_ops_with_different_meshes(self):
10161012
# Tests whether we can save and load a module with meshes that have the
10171013
# same axis sizes (and same order) but different axis names.
@@ -1044,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4]
10441040
# the expected custom call targets for old test data that was serialized
10451041
# with custom calls.
10461042
for data, custom_call_targets_override in data:
1047-
with Mesh(devices, axis_names=('x')):
1043+
with jax.set_mesh(Mesh(devices, axis_names=('x'))):
10481044
self.run_one_test(
10491045
func, self.load_testdata(data),
10501046
expect_current_custom_calls=custom_call_targets_override)

0 commit comments

Comments
 (0)