11 changes: 4 additions & 7 deletions jax/_src/lax/lax.py
@@ -2696,7 +2696,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
operand: an array
shape: the shape of the target array
broadcast_dimensions: to which dimension in the target shape each dimension
of the operand shape corresponds to. That is, dimension i of the operand
of the operand shape corresponds to. That is, dimension i of the operand
becomes dimension broadcast_dimensions[i] of the result.

Returns:
@@ -2705,18 +2705,15 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
# TODO(dfm): Re-write this as a "reshard" when only the sharding changes.
out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim')
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
isinstance(operand, Array) and out_sharding is None):
return operand
operand_aval = core.typeof(operand)
if (operand_aval.shape == shape and
list(broadcast_dimensions) == list(range(operand_aval.ndim)) and
out_sharding is not None and operand_aval.sharding != out_sharding):
return pjit.reshard(operand, out_sharding)
return broadcast_in_dim_p.bind(
operand, shape=tuple(shape),
broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding)
broadcast_dimensions=tuple(broadcast_dimensions),
sharding=out_sharding)

def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
"""Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
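The `broadcast_in_dim` diff above concerns the case where the target shape and `broadcast_dimensions` leave the operand unchanged and only `out_sharding` differs, i.e. a pure resharding. Below is a hedged sketch of that call, adapted from the `test_no_op_broadcast_except_for_sharding_change` test removed in the next file; the explicit-mesh setup (`jax.make_mesh` with `axis_types`, `jax.sharding.use_mesh`) assumes a recent JAX and at least two devices.

```python
import jax
import jax.numpy as jnp
from jax.sharding import AxisType, PartitionSpec as P

# Illustrative 1-D explicit mesh over two devices; the axis name 'x' follows
# the removed test.
mesh = jax.make_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))

with jax.sharding.use_mesh(mesh):
    arr = jnp.arange(8.).reshape(4, 2)

    @jax.jit
    def to_sharded(x):
        # Same shape (4, 2) and identity broadcast_dimensions (0, 1): the only
        # effect of this call is the requested output sharding.
        return jax.lax.broadcast_in_dim(x, (4, 2), (0, 1), out_sharding=P('x'))

    out = to_sharded(arr)
    print(out.sharding.spec)  # PartitionSpec('x', None), per the removed test
```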
17 changes: 0 additions & 17 deletions tests/pjit_test.py
@@ -5399,23 +5399,6 @@ def f(x, y):
ValueError, "For primitive.*context mesh.*aval mesh"):
f(arr1, arr2)

@jtu.with_explicit_mesh((2,), 'x')
def test_no_op_broadcast_except_for_sharding_change(self, mesh):
arr = jnp.arange(8.).reshape(4, 2)

@jax.jit
def f(x):
out = jax.lax.broadcast_in_dim(x, (4, 2), [0, 1], out_sharding=P('x'))
self.assertEqual(out.aval.sharding.spec, P('x', None))
return out

out = f(arr)
self.assertArraysEqual(out, arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr)
self.assertEqual(out_g.sharding, NamedSharding(mesh, P(None, None)))

@jtu.with_explicit_mesh((2, 2), ('x', 'y'))
def test_sin_unop(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)