
Commit d5df192

danielsuo authored and Google-ML-Automation committed
[pmap] Add more detailed documentation about int array indexing in JAX.
PiperOrigin-RevId: 845080355
1 parent 6395b5f commit d5df192

File tree

1 file changed: +150 −17 lines changed


docs/migrate_pmap.md

Lines changed: 150 additions & 17 deletions
@@ -92,6 +92,49 @@ Mesh('y': 4, axis_types=(Auto,))

## Performance implications

### `int` indexing into sharded arrays

The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy
`PmapSharding`. We've observed a common pattern with the old `jax.pmap` in which
users shard stacked copies of an array in order to replicate it (e.g., via
`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays
suffer unnecessary communication overhead when indexed with an `int` (e.g.,
`x[0]`) because JAX does not know the arrays are actually replicated. For a more
thorough discussion, please see [Appendix A](#appendix-a).
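
The following minimal sketch (not part of the original guide) reproduces this
pattern; it uses the deprecated `jax.device_put_replicated` only to show how
"replicated" copies end up sharded along a new leading dimension:

```python
import jax
import jax.numpy as jnp

# Stacks one copy per device and shards the stack along the new leading axis,
# producing a "sharded-but-really-replicated" array.
x = jax.device_put_replicated(jnp.zeros((3, 4)), jax.devices())

print(x.shape)     # (num_devices, 3, 4)
print(x.sharding)  # sharded along the leading dimension
# x[0] now requires communication: JAX cannot prove the copies are identical.
```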

#### Option 1: Prevent unintended sharding (recommended)
Avoid creating the leading sharded dimension entirely (see the sketch after this
list):

- Use `jax.pmap`'s `out_axes=None` for outputs that should remain replicated.
  The output will be fully replicated (e.g., `P(None, None)`), making access
  cheap.
- For inputs: when using `jax.device_put`, specify `jax.P()` (fully replicated)
  in the partition spec rather than relying on utilities that stack and shard.
  (Note: `jax.device_put_replicated` and `jax.device_put_sharded` are deprecated
  because they confusingly produce sharded arrays rather than replicated ones.)
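
As a rough sketch of both suggestions (the toy `step` function, mesh setup, and
variable names are illustrative assumptions, not taken from the JAX docs):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding

n = jax.device_count()
mesh = jax.make_mesh((n,), ('x',))

# Inputs: keep parameters fully replicated (jax.P() shards nothing) instead of
# stacking one copy per device.
params = jax.device_put(jnp.zeros((3, 4)), NamedSharding(mesh, jax.P()))

# in_axes=(0, None): map over the per-device batch, broadcast the params.
# out_axes=(0, None): keep per-device outputs stacked, keep the metric replicated.
step = jax.pmap(
    lambda b, p: (b @ p, jnp.mean(p)),
    in_axes=(0, None),
    out_axes=(0, None),
)

batch = jnp.ones((n, 2, 3))
outputs, metric = step(batch, params)
print(metric)  # replicated, so reading it on the host is cheap (no gather)
```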

#### Option 2: Access local data directly
If you must work with a sharded array (or want to change less code), you can
access the local data shard directly without triggering JAX's distributed
consistency checks. Note that this is only recommended when bringing data back
to the host (e.g., for logging or checkpointing). Instead of `x[0]`, use
`addressable_shards`:

```python
# Old slow way:
# result = x[0]

# New fast way:
# x.addressable_shards is a list of shards on the current process.
# We grab the first one, extract the data, and remove the leading dimension.
result = x.addressable_shards[0].data.squeeze(0)
```

In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data`
returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in
the desired `(3, 4)` shape without any cross-device communication. Both
solutions will eliminate the `_gather` operations seen in profiling.

### Host local array to global array round-trip conversion

In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be

@@ -104,23 +147,6 @@ host-local array when returning to user code.
This round-trip conversion cannot be avoided, so if the performance penalty is
too great, we recommend migrating your code to `jax.shard_map`.

-### `int` array indexing
-
-Indexing into a sharded array with an int (e.g., `arr[0]`) may now execute a
-rank reduction computation. Depending on your use case, there may be
-workarounds:
-
-1. In a typical training loop, we might use a `jax.pmap`ed update function to
-   operate on / carry training state and grab resulting metrics from the first
-   `jax.pmap`'ed device for logging. In this case, it may be possible to
-   use `None` for the relevant `in_axes` and `out_axes` passed to `jax.pmap`.
-   This lets `jax.pmap` handle replication and will return an
-   appropriately-shaped result that looks like it's from a single device for,
-   say, logging metrics.
-2. More generally, you can get the first shard of data without a reshape via
-   `arr[0:1]` or `arr.addressable_shards[0].data`. Note that this will have a
-   leading `(1,)` dimension that your code will need to handle.
-
## Migrating to `jax.shard_map`

In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by

@@ -132,4 +158,111 @@ dispatch path as in the `jax.shard_map` implementation of `jax.pmap` and can
often be overlapped with compute or be called infrequently (i.e., before a train
loop and for occasionally grabbing metrics).
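
As a rough sketch of such a migration (the toy `step` function and mesh setup
are illustrative assumptions, not taken from this guide):

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

n = jax.device_count()
mesh = jax.make_mesh((n,), ('x',))

def step(x):
  # Purely per-device computation; collectives would use the 'x' axis name.
  return x * 2.0

# Old: jax.pmap maps `step` over a leading axis of size jax.device_count().
step_pmap = jax.pmap(step, axis_name='x')

# New: jax.jit(jax.shard_map(...)) with explicit input/output specs on the mesh.
step_shard = jax.jit(
    jax.shard_map(step, mesh=mesh, in_specs=P('x'), out_specs=P('x')))

x = jnp.arange(n * 4.0).reshape(n, 4)
print(step_pmap(x).shape, step_shard(x).shape)  # both (n, 4)
```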

(appendix-a)=
## Appendix A: More details about `int` indexing into sharded arrays

### What should `x[0]` return?

In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice
along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]`
returns an array of shape `(3, 4)`.
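
For reference, a tiny NumPy snippet demonstrating the behavior described above:

```python
import numpy as np

x = np.ones((8, 3, 4))
print(x[0].shape)  # (3, 4): the rank-reduced first slice along axis 0
```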

In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the
rank-reduced slice of the logical array `x`. However, performance depends on how
`x` is sharded or replicated across devices. Consider an array `x` with shape
`(8, 3, 4)` distributed across 8 devices (using `jax.P` as the short name for
`jax.sharding.PartitionSpec`); the three cases below are illustrated in a sketch
after the list:

1. **Fully Replicated:** `jax.P(None, None, None)`
   If `x` is fully replicated, every device holds a complete copy of the `(8,
   3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec
   `jax.P(None, None)`. Since every device already has `x`, this operation will
   slice on each device independently and requires **no communication**.

2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)`
   If `x` is sharded along the second dimension, `x[0]` results in shape `(3,
   4)` with partition spec `jax.P('x', None)`. Since the first dimension (the
   one being sliced) is unsharded, this operation also requires **no
   communication**.

3. **Sharded on Leading Dimension:** `jax.P('x', None, None)`
   If `x` is sharded along the first dimension, `x[0]` results in shape `(3,
   4)` with partition spec `jax.P(None, None)`.
   * **The Issue:** Because the first dimension is sharded, the data for
     `x[0]` physically resides *only* on the first device. To satisfy the
     output sharding `jax.P(None, None)` (which implies replication), JAX
     must broadcast the data from the first device to all other devices. This
     requires **communication**; JAX will gather the *entire* array of shape
     `(8, 3, 4)` to each device and then take a slice.
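
A minimal sketch of these three cases (assuming 8 available devices; the mesh
axis name `'x'` follows the convention used throughout this doc):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding

# Assumes 8 devices are available (e.g., real accelerators, or
# XLA_FLAGS=--xla_force_host_platform_device_count=8 for host CPUs).
mesh = jax.make_mesh((8,), ('x',))
x = jnp.ones((8, 3, 4))

for spec in (jax.P(None, None, None),   # 1. fully replicated
             jax.P(None, 'x', None),    # 2. sharded on a non-leading dim
             jax.P('x', None, None)):   # 3. sharded on the leading dim
  xs = jax.device_put(x, NamedSharding(mesh, spec))
  y = xs[0]                             # same logical (3, 4) result in all cases
  # Only case 3 requires cross-device communication to produce a replicated y.
  print(spec, y.shape, y.sharding)
```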

### The Common Performance Pitfall

A common pattern among `jax.pmap` users involves arrays that are **semantically
replicated** (the user intends for them to be identical everywhere) but are
**physically sharded** (stacked along the leading dimension).

This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly
(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or
checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap
operation.

#### Example: The "Unreplicate" Anti-Pattern

```python
import jax
import jax.numpy as jnp
from flax import jax_utils

# jax_utils.replicate calls jax.device_put_replicated.
# This stacks num_devices copies and SHARDS them over the stacked dimension.
# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None)
train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))})

# out_axes=0 by default, so the output remains sharded along dim 0.
train_step_pmapped = jax.pmap(lambda x: x)

# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree).
# Users do this to grab metrics, log param statistics, checkpoint, etc.
train_state = jax_utils.unreplicate(train_step_pmapped(train_state))
```

#### The Consequence
Even though the user knows `train_state` contains identical data on every
device, JAX sees an array with shape `(8, 3, 4)` and spec `jax.P('x', None,
None)`, i.e., an array that is sharded along its leading dimension. JAX cannot
safely assume the data is identical on each device. Therefore, `x[0]` triggers a
gather of the entire array to all devices before slicing to ensure correctness.
This unnecessary communication causes performance degradation (visible as
`_gather` operations in a stack trace).

```
train
└─ jax_utils.py:48 unreplicate
   └─ tree_util.py:354 tree_map
      └─ jax_utils.py:50 <lambda> (performing x[0])
         └─ array.py:335 __getitem__
            └─ indexing.py:734 rewriting_take

               └─ indexing.py:784 _gather
                  └─ slicing.py:324 gather
                     └─ PjitFunction(gather)
```

### Why was "Old Pmap" Fast?
Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in
`jax.Array`'s `__getitem__` allowing it to return an array with a
`SingleDeviceSharding` (data residing on only one device).

However, current JAX uses `NamedSharding`. We do not strictly replicate the
legacy behavior because it breaks the semantics of array indexing. If we allowed
`x[0]` to return a `SingleDeviceSharding` array in a general context (e.g., in
the middle of a train step rather than when bringing data back to the host for
reporting), only one device would have data while the others would have nothing,
which is computationally problematic for subsequent operations.

The slowdown users experience now is JAX enforcing correct semantics: if you ask
for `x[0]` from an array sharded along its leading dimension, you get a fully
replicated result available on all devices, which requires communication.

<!--* freshness: { reviewed: '2025-09-29' } *-->
