Skip to content

Commit ed4d825

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix spmd_axis_name assert with explicit_mesh_axis in presence of multi-character mesh axis name
PiperOrigin-RevId: 845351158
1 parent bb33967 commit ed4d825

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

jax/_src/api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,9 +1191,9 @@ def vmap_f(*args, **kwargs):
11911191
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
11921192
explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
11931193
if spmd_axis_name is not None and explicit_mesh_axis is not None:
1194-
spmd_axis_name = (
1195-
tuple(*core.remove_size_one_mesh_axis(P(spmd_axis_name), get_abstract_mesh()))
1196-
if config.remove_size_one_mesh_axis_from_type.value else spmd_axis_name)
1194+
if config.remove_size_one_mesh_axis_from_type.value:
1195+
mesh = get_abstract_mesh()
1196+
spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1)
11971197
if spmd_axis_name == explicit_mesh_axis:
11981198
spmd_axis_name = None
11991199
else:

tests/pjit_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7275,11 +7275,11 @@ def f(x):
72757275
f(arr)
72767276

72777277
@parameterized.parameters(
7278-
(('x', 'y', 'z'), ('x', 'y')),
7279-
(('x', 'z'), 'x')
7278+
(('data', 'model', 'stage'), ('data', 'model')),
7279+
(('data', 'stage'), 'data')
72807280
)
72817281
@config.remove_size_one_mesh_axis_from_type(True)
7282-
@jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z'))
7282+
@jtu.with_explicit_mesh((2, 2, 1), ('data', 'model', 'stage'))
72837283
def test_spmd_axis_name_explicit_mode_assert_remove_one_size(
72847284
self, in_spec, out_spec, mesh):
72857285
np_inp = np.arange(16).reshape(4, 2, 2)

0 commit comments

Comments
 (0)