diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 71f063135001..ca0e2b1f0a8f 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2696,7 +2696,7 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
     operand: an array
     shape: the shape of the target array
     broadcast_dimensions: to which dimension in the target shape each dimension
-      of the operand shape corresponds to. That is, dimension i of the operand
+      of the operand shape corresponds to. That is, dimension i of the operand
       becomes dimension broadcast_dimensions[i] of the result.

   Returns:
@@ -2705,18 +2705,15 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
   See Also:
     jax.lax.broadcast : simpler interface to add new leading dimensions.
   """
+  # TODO(dfm): Re-write this as a "reshard" when only the sharding changes.
   out_sharding = canonicalize_sharding(out_sharding, 'broadcast_in_dim')
   if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
       isinstance(operand, Array) and out_sharding is None):
     return operand
-  operand_aval = core.typeof(operand)
-  if (operand_aval.shape == shape and
-      list(broadcast_dimensions) == list(range(operand_aval.ndim)) and
-      out_sharding is not None and operand_aval.sharding != out_sharding):
-    return pjit.reshard(operand, out_sharding)
   return broadcast_in_dim_p.bind(
       operand, shape=tuple(shape),
-      broadcast_dimensions=tuple(broadcast_dimensions), sharding=out_sharding)
+      broadcast_dimensions=tuple(broadcast_dimensions),
+      sharding=out_sharding)

 def broadcast_to_rank(x: ArrayLike, rank: int) -> Array:
   """Adds leading dimensions of ``1`` to give ``x`` rank ``rank``."""
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 8c71d5132667..243772588c91 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -5399,23 +5399,6 @@ def f(x, y):
         ValueError, "For primitive.*context mesh.*aval mesh"):
       f(arr1, arr2)

-  @jtu.with_explicit_mesh((2,), 'x')
-  def test_no_op_broadcast_except_for_sharding_change(self, mesh):
-    arr = jnp.arange(8.).reshape(4, 2)
-
-    @jax.jit
-    def f(x):
-      out = jax.lax.broadcast_in_dim(x, (4, 2), [0, 1], out_sharding=P('x'))
-      self.assertEqual(out.aval.sharding.spec, P('x', None))
-      return out
-
-    out = f(arr)
-    self.assertArraysEqual(out, arr)
-    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-
-    out_g = jax.jit(jax.grad(lambda x: f(x).sum()))(arr)
-    self.assertEqual(out_g.sharding, NamedSharding(mesh, P(None, None)))
-
   @jtu.with_explicit_mesh((2, 2), ('x', 'y'))
   def test_sin_unop(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
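
For readers unfamiliar with `broadcast_in_dim`, the sketch below illustrates the `broadcast_dimensions` mapping described in the docstring hunk above: dimension `i` of the operand becomes dimension `broadcast_dimensions[i]` of the result. It is an illustration only, not part of the patch, and it leaves `out_sharding` unset, since the removed fast path only applied when a call changed the sharding alone (it returned `pjit.reshard(operand, out_sharding)` instead of binding the primitive; the `TODO(dfm)` notes this may come back as a reshard).

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.).reshape(2, 3)  # operand of shape (2, 3)

# Operand dim 0 -> result dim 1, operand dim 1 -> result dim 2; the new
# leading result dimension of size 4 is created by broadcasting.
y = lax.broadcast_in_dim(x, shape=(4, 2, 3), broadcast_dimensions=(1, 2))
assert y.shape == (4, 2, 3)

# Size-1 operand dimensions may map onto larger result dimensions.
z = lax.broadcast_in_dim(jnp.ones((1, 3)), shape=(5, 3),
                         broadcast_dimensions=(0, 1))
assert z.shape == (5, 3)
```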