
Commit 6940903

yashk2810 authored and Google-ML-Automation committed
Dispatch to reshard in broadcast_in_dim if only the sharding is changing and everything else is the same as the operand.
PiperOrigin-RevId: 845590697
1 parent 4e1d8f9 · commit 6940903

File tree

2 files changed: +24 −4 lines changed


jax/_src/lax/lax.py

Lines changed: 7 additions & 4 deletions
@@ -2696,7 +2696,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
     operand: an array
     shape: the shape of the target array
     broadcast_dimensions: to which dimension in the target shape each dimension
-      of the operand shape corresponds to. That is, dimension i of the operand
+      of the operand shape corresponds to. That is, dimension i of the operand
       becomes dimension broadcast_dimensions[i] of the result.
 
   Returns:
@@ -2705,15 +2705,18 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
-  # TODO(dfm): Re-write this as a "reshard" when only the sharding changes.
   out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim')
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
       isinstance(operand, Array) and out_sharding is None):
     return operand
+  operand_aval = core.typeof(operand)
+  if (operand_aval.shape == shape and
+      list(broadcast_dimensions) == list(range(operand_aval.ndim)) and
+      out_sharding is not None and operand_aval.sharding != out_sharding):
+    return pjit.reshard(operand, out_sharding)
   return broadcast_in_dim_p.bind(
       operand, shape=tuple(shape),
-      broadcast_dimensions=tuple(broadcast_dimensions),
-      sharding=out_sharding)
+      broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding)
 
 def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""

tests/pjit_test.py

Lines changed: 17 additions & 0 deletions
@@ -5399,6 +5399,23 @@ def f(x, y):
         ValueError, "For primitive.*context mesh.*aval mesh"):
       f(arr1, arr2)
 
+  @jtu.with_explicit_mesh((2,), 'x')
+  def test_no_op_broadcast_except_for_sharding_change(self, mesh):
+    arr = jnp.arange(8.).reshape(4, 2)
+
+    @jax.jit
+    def f(x):
+      out = jax.lax.broadcast_in_dim(x, (4, 2), [0, 1], out_sharding=P('x'))
+      self.assertEqual(out.aval.sharding.spec, P('x', None))
+      return out
+
+    out = f(arr)
+    self.assertArraysEqual(out, arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
+    out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr)
+    self.assertEqual(out_g.sharding, NamedSharding(mesh, P(None, None)))
+
   @jtu.with_explicit_mesh((2, 2), ('x', 'y'))
   def test_sin_unop(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
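
The new test checks three things: under jit the output aval picks up the requested spec P('x', None), the values round-trip unchanged, and the gradient comes back with the operand's replicated sharding. The last assertion holds because the transpose of a reshard is a reshard back to the primal's sharding. A standalone sketch of that gradient check, under the same assumed two-device explicit mesh as above:

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Assumed setup, mirroring the test's (2,) mesh over axis 'x'.
mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

def loss(a):
  # Reshard the replicated operand onto 'x', then reduce to a scalar.
  return jax.lax.broadcast_in_dim(a, (4, 2), [0, 1], out_sharding=P('x')).sum()

with jax.sharding.use_mesh(mesh):
  x = jnp.arange(8.).reshape(4, 2)
  g = jax.jit(jax.grad(loss))(x)
  # The transpose of the reshard reshards the cotangent back to the primal's
  # sharding, so the gradient is replicated again: P(None, None).
  assert g.sharding == NamedSharding(mesh, P(None, None))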
