@@ -784,8 +784,6 @@ def func(x):
784784 data = self .load_testdata (cuda_threefry2x32 .data_2024_07_30 )
785785 self .run_one_test (func , data )
786786
787- @jtu .ignore_warning (category = DeprecationWarning ,
788- message = '`with mesh:` context manager' )
789787 def test_tpu_sharding (self ):
790788 # Tests "Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape" on TPU
791789 if not jtu .test_device_matches (["tpu" ]) or len (jax .devices ()) < 2 :
@@ -815,7 +813,7 @@ def func(x): # b: f32[2, 4]
815813 # the expected custom call targets for old test data that was serialized
816814 # with custom calls.
817815 for data , custom_call_targets_override in data :
818- with mesh :
816+ with jax . set_mesh ( mesh ) :
819817 if jax .config .jax_use_shardy_partitioner :
820818 self .run_one_test (
821819 func , self .load_testdata (data ["shardy" ]),
@@ -1010,8 +1008,6 @@ def check_top_k_results(res_run, res_expected, *, rtol, atol):
10101008
10111009class ShardyCompatTest (bctu .CompatTestBase ):
10121010
1013- @jtu .ignore_warning (category = DeprecationWarning ,
1014- message = '`with mesh:` context manager' )
10151011 def test_shardy_sharding_ops_with_different_meshes (self ):
10161012 # Tests whether we can save and load a module with meshes that have the
10171013 # same axis sizes (and same order) but different axis names.
@@ -1044,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4]
10441040 # the expected custom call targets for old test data that was serialized
10451041 # with custom calls.
10461042 for data , custom_call_targets_override in data :
1047- with Mesh (devices , axis_names = ('x' )):
1043+ with jax . set_mesh ( Mesh (devices , axis_names = ('x' ) )):
10481044 self .run_one_test (
10491045 func , self .load_testdata (data ),
10501046 expect_current_custom_calls = custom_call_targets_override )
0 commit comments