## Performance implications

### `int` indexing into sharded arrays

The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy
`PmapSharding`. We've observed a common pattern with the old `jax.pmap` where
users shard stacked copies of an array in order to replicate it (e.g., via
`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays
suffer unnecessary communication overhead when indexed with an `int` (e.g.,
`x[0]`) because JAX does not know the arrays are actually replicated. For a more
thorough discussion, please see [Appendix A](#appendix-a).

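To check whether an array is in this state, inspect its sharding. A minimal
sketch (the axis name `'x'` and the manual stack-and-`device_put` below are
illustrative stand-ins for what stacking utilities such as
`jax.device_put_replicated` effectively produce):

```python
import jax
import jax.numpy as jnp

# Stack one copy of the parameters per device, then shard the stacked (leading)
# axis across devices -- the "sharded-but-really-replicated" pattern.
mesh = jax.make_mesh((jax.device_count(),), ('x',))
params = jnp.zeros((3, 4))
stacked = jnp.stack([params] * jax.device_count())
x = jax.device_put(stacked, jax.sharding.NamedSharding(mesh, jax.P('x')))

# The leading dimension is sharded, so JAX must treat the copies as distinct
# shards even though they hold identical values.
print(x.shape)          # (num_devices, 3, 4)
print(x.sharding.spec)  # PartitionSpec('x',)
```
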
#### Option 1: Prevent unintended sharding (recommended)
Avoid creating the leading sharded dimension entirely (a short sketch follows
this list).

- Use `jax.pmap`'s `out_axes=None` for outputs that should remain replicated.
  The output will be fully replicated (e.g., `P(None, None)`), making access
  cheap.
- For inputs: when using `jax.device_put`, specify `jax.P()` (fully replicated)
  in the partition spec rather than relying on utilities that stack and shard.
  (Note: `jax.device_put_replicated` and `jax.device_put_sharded` are deprecated
  because they confusingly produce sharded arrays rather than replicated ones.)

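A rough sketch of both points (the `step` function, the shapes, and the axis
name `'x'` are made up for illustration; `out_axes=None` is only valid for
outputs that do not depend on the mapped axis):

```python
import jax
import jax.numpy as jnp

n = jax.device_count()
mesh = jax.make_mesh((n,), ('x',))

# Inputs: replicate explicitly with an empty partition spec instead of stacking
# one copy per device.
params = jax.device_put(jnp.zeros((3, 4)),
                        jax.sharding.NamedSharding(mesh, jax.P()))
# Per-device data stays sharded along its leading axis as usual.
batch = jax.device_put(jnp.ones((n, 2, 3)),
                       jax.sharding.NamedSharding(mesh, jax.P('x')))

def step(p, x):
    loss = jnp.mean(x @ p)       # depends on the mapped batch (per-device)
    param_norm = jnp.sum(p * p)  # depends only on the replicated params
    return loss, param_norm

# in_axes=None consumes the replicated params without a stacked copy per
# device; out_axes=None returns the identical-everywhere statistic once, with
# no leading device dimension to index away.
loss, param_norm = jax.pmap(
    step, in_axes=(None, 0), out_axes=(0, None))(params, batch)
print(loss.shape)        # (num_devices,)
print(param_norm.shape)  # ()
```
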
#### Option 2: Access local data directly
If you must work with a sharded array (or want to make fewer changes to your
code), you can access the local data shard directly without triggering JAX's
distributed consistency checks. Note that this is only recommended when
bringing data back to the host (e.g., for logging or checkpointing). Instead of
`x[0]`, use `addressable_shards`:

```python
# Old slow way:
# result = x[0]

# New fast way:
# x.addressable_shards is a list of shards on the current process.
# We grab the first one, extract the data, and remove the leading dimension.
result = x.addressable_shards[0].data.squeeze(0)
```

In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data`
returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in
the desired `(3, 4)` shape without any cross-device communication. Both
solutions eliminate the `_gather` operations seen in profiling.

### Host local array to global array round-trip conversion

In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be
host-local, with each process holding only its own addressable shards. The
`jax.shard_map`-based implementation of `jax.pmap` converts such inputs to
global arrays before dispatch and converts the results back to a host-local
array when returning to user code.

This round-trip conversion cannot be avoided, so if the performance penalty is
too great, we recommend migrating your code to `jax.shard_map`.

## Migrating to `jax.shard_map`

In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by
setting up a `Mesh` and sharding their inputs explicitly. Any explicit
host-local to global array conversion this requires does not go through the
same conversion and dispatch path as in the `jax.shard_map` implementation of
`jax.pmap` and can often be overlapped with compute or be called infrequently
(i.e., before a train loop and for occasionally grabbing metrics).

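A rough sketch of such a migration (the function, shapes, and axis name `'x'`
are illustrative, not taken from any particular `jax.pmap` program):

```python
from functools import partial

import jax
import jax.numpy as jnp

mesh = jax.make_mesh((jax.device_count(),), ('x',))

# Per-device computation, mapped over the 'x' mesh axis; cross-device
# reductions are written explicitly with collectives.
@jax.jit
@partial(jax.shard_map, mesh=mesh,
         in_specs=jax.P('x', None), out_specs=jax.P())
def global_mean(x):
    return jax.lax.pmean(jnp.mean(x), axis_name='x')

# Shard the global batch along its leading axis, matching in_specs.
batch = jax.device_put(
    jnp.arange(float(jax.device_count() * 4)).reshape(-1, 4),
    jax.sharding.NamedSharding(mesh, jax.P('x', None)))
print(global_mean(batch))  # a single, replicated scalar
```
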
(appendix-a)=
## Appendix A: More details about `int` indexing into sharded arrays

### What should `x[0]` return?

In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice
along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]`
returns an array of shape `(3, 4)`.

In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the
rank-reduced slice of the logical array `x`. However, performance depends on how
`x` is sharded or replicated across devices. Consider an array `x` with shape
`(8, 3, 4)` distributed across 8 devices (using `jax.P` as the short name for
`jax.sharding.PartitionSpec`); a code sketch of all three cases follows the
list:

1. **Fully Replicated:** `jax.P(None, None, None)`
   If `x` is fully replicated, every device holds a complete copy of the
   `(8, 3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec
   `jax.P(None, None)`. Since every device already has `x`, this operation
   slices on each device independently and requires **no communication**.

2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)`
   If `x` is sharded along the second dimension, `x[0]` results in shape
   `(3, 4)` with partition spec `jax.P('x', None)`. Since the first dimension
   (the one being sliced) is unsharded, this operation also requires **no
   communication**.

3. **Sharded on Leading Dimension:** `jax.P('x', None, None)`
   If `x` is sharded along the first dimension, `x[0]` results in shape
   `(3, 4)` with partition spec `jax.P(None, None)`.
   * **The Issue:** Because the first dimension is sharded, the data for
     `x[0]` physically resides *only* on the first device. To satisfy the
     output sharding `jax.P(None, None)` (which implies replication), JAX
     must broadcast the data from the first device to all other devices. This
     requires **communication**; JAX will gather the *entire* array of shape
     `(8, 3, 4)` to each device and then take a slice.

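A minimal sketch of the three cases (shapes are chosen so that every sharded
dimension divides evenly across the available devices; the axis name `'x'` is
arbitrary):

```python
import jax
import jax.numpy as jnp

n = jax.device_count()
mesh = jax.make_mesh((n,), ('x',))
x = jnp.ones((n, 2 * n, 4))

def put(spec):
    return jax.device_put(x, jax.sharding.NamedSharding(mesh, spec))

x_rep = put(jax.P(None, None, None))   # case 1: no communication on x[0]
x_mid = put(jax.P(None, 'x', None))    # case 2: no communication on x[0]
x_lead = put(jax.P('x', None, None))   # case 3: x[0] gathers across devices

for name, arr in [('replicated', x_rep), ('non-leading', x_mid),
                  ('leading', x_lead)]:
    y = arr[0]
    print(name, y.shape, y.sharding)
```
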
### The Common Performance Pitfall

A common pattern among `jax.pmap` users involves arrays that are **semantically
replicated** (the user intends for them to be identical everywhere) but are
**physically sharded** (stacked along the leading dimension).

This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly
(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or
checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap
operation.

#### Example: The "Unreplicate" Anti-Pattern

```python
from flax import jax_utils
import jax.numpy as jnp
import jax

# jax_utils.replicate calls jax.device_put_replicated.
# This stacks num_devices copies and SHARDS them over the stacked dimension.
# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None)
train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))})

# out_axes=0 by default, so the output remains sharded along dim 0.
train_step_pmapped = jax.pmap(lambda x: x)

# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree).
# Users do this to grab metrics, log param statistics, checkpoint, etc.
train_state = jax_utils.unreplicate(train_step_pmapped(train_state))
```

#### The Consequence
Even though the user knows `train_state` contains identical data on every
device, JAX sees an array with shape `(8, 3, 4)` and spec
`jax.P('x', None, None)`, i.e., an array that is sharded along its leading
dimension. JAX cannot safely assume the data is identical on each device.
Therefore, `x[0]` triggers a gather of the entire array to all devices before
slicing to ensure correctness. This unnecessary communication causes
performance degradation (visible as `_gather` operations in a stack trace):

```
train
└─ jax_utils.py:48 unreplicate
   └─ tree_util.py:354 tree_map
      └─ jax_utils.py:50 <lambda> (performing x[0])
         └─ array.py:335 __getitem__
            └─ indexing.py:734 rewriting_take
               │
               ▼
               └─ indexing.py:784 _gather
                  └─ slicing.py:324 gather
                     └─ PjitFunction(gather)
```

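With the recommendations above in mind, one way to avoid the gather when
pulling data back to the host is to read the first local shard rather than
indexing with `x[0]`. A minimal sketch along the lines of Option 2 (the helper
name `unreplicate_local` is hypothetical, not a flax or JAX API):

```python
import jax

# Read the first addressable shard of each leaf instead of computing x[0], so
# no cross-device gather is issued. Only appropriate for host-side uses such
# as logging or checkpointing.
def unreplicate_local(tree):
    return jax.tree.map(
        lambda x: x.addressable_shards[0].data.squeeze(0), tree)
```
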
### Why was "Old Pmap" Fast?
Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in
`jax.Array`'s `__getitem__` allowing it to return an array with a
`SingleDeviceSharding` (data residing on only one device).

However, current JAX uses `NamedSharding`. We do not strictly replicate the
legacy behavior because it breaks the semantics of array indexing. If we allowed
`x[0]` to return a `SingleDeviceSharding` array in a general context (e.g., in
the middle of a train step rather than only when bringing data back to the host
for reporting), only one device would have data while the others would have
nothing. This is computationally problematic for subsequent operations.

The slowdown users experience now is JAX enforcing correct semantics: if you ask
for `x[0]` from an array sharded along its leading dimension, you get a fully
replicated result available on all devices, which requires communication.

<!-- * freshness: { reviewed: '2025-09-29' } *-->