Skip to content

Commit 58fbb8a

Browse files
Merge pull request #33885 from gnecula:fix_tests
PiperOrigin-RevId: 845285515
2 parents 21b8652 + 3cbe137 commit 58fbb8a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/export_back_compat_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def func(x): # b: f32[2, 4]
813813
# the expected custom call targets for old test data that was serialized
814814
# with custom calls.
815815
for data, custom_call_targets_override in data:
816-
with mesh:
816+
with jax.set_mesh(mesh):
817817
if jax.config.jax_use_shardy_partitioner:
818818
self.run_one_test(
819819
func, self.load_testdata(data["shardy"]),
@@ -1040,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4]
10401040
# the expected custom call targets for old test data that was serialized
10411041
# with custom calls.
10421042
for data, custom_call_targets_override in data:
1043-
with Mesh(devices, axis_names=('x')):
1043+
with jax.set_mesh(Mesh(devices, axis_names=('x'))):
10441044
self.run_one_test(
10451045
func, self.load_testdata(data),
10461046
expect_current_custom_calls=custom_call_targets_override)

0 commit comments

Comments
 (0)