Skip to content

Commit b44a511

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
Reverts 58fbb8a
PiperOrigin-RevId: 845733949
1 parent b8d3757 commit b44a511

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/export_back_compat_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ def func(x): # b: f32[2, 4]
813813
# the expected custom call targets for old test data that was serialized
814814
# with custom calls.
815815
for data, custom_call_targets_override in data:
816-
with jax.set_mesh(mesh):
816+
with mesh:
817817
if jax.config.jax_use_shardy_partitioner:
818818
self.run_one_test(
819819
func, self.load_testdata(data["shardy"]),
@@ -1040,7 +1040,7 @@ def shard_map_func(x): # b: f32[2, 4]
10401040
# the expected custom call targets for old test data that was serialized
10411041
# with custom calls.
10421042
for data, custom_call_targets_override in data:
1043-
with jax.set_mesh(Mesh(devices, axis_names=('x'))):
1043+
with Mesh(devices, axis_names=('x')):
10441044
self.run_one_test(
10451045
func, self.load_testdata(data),
10461046
expect_current_custom_calls=custom_call_targets_override)

0 commit comments

Comments
 (0)