diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 41a4b99ed944..e41d29a53cf8 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -813,7 +813,7 @@ def func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with mesh: + with jax.set_mesh(mesh): if jax.config.jax_use_shardy_partitioner: self.run_one_test( func, self.load_testdata(data["shardy"]), @@ -1040,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4] # the expected custom call targets for old test data that was serialized # with custom calls. for data, custom_call_targets_override in data: - with Mesh(devices, axis_names=('x')): + with jax.set_mesh(Mesh(devices, axis_names=('x'))): self.run_one_test( func, self.load_testdata(data), expect_current_custom_calls=custom_call_targets_override)