
## Performance implications

### `int` indexing into sharded arrays

The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy
`PmapSharding`. We've observed a common pattern with the old `jax.pmap` where
users shard stacked copies of an array in order to replicate it (e.g., via
`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays
suffer unnecessary communication overhead when indexed with an `int` (e.g.,
`x[0]`), because JAX does not know the arrays are actually replicated. For a
more thorough discussion, please see [Appendix A](#appendix-a).

#### Option 1: Prevent unintended sharding (recommended)
Avoid creating the leading sharded dimension entirely.

- Use `jax.pmap`'s `out_axes=None` for outputs that should remain replicated.
  The output will be fully replicated (e.g., `P(None, None)`), making access
  cheap.
- For inputs: when using `jax.device_put`, specify `jax.P()` (fully replicated)
  in the partition spec rather than relying on utilities that stack and shard;
  see the sketch after this list. (Note: `jax.device_put_replicated` and
  `jax.device_put_sharded` are deprecated because they confusingly produce
  sharded arrays rather than replicated ones.)
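
A minimal sketch of both suggestions is below. It assumes a recent JAX release
(for `jax.make_mesh` and `jax.P`); the mesh construction, the toy `step`
function, and the shapes are illustrative assumptions rather than a prescribed
recipe:

```python
import jax
import jax.numpy as jnp

# Keep params fully replicated instead of stacking one copy per device.
mesh = jax.make_mesh((jax.device_count(),), ('x',))
params = jax.device_put(
    jnp.zeros((3, 4)),
    jax.sharding.NamedSharding(mesh, jax.P()),  # fully replicated
)

# in_axes=(None, 0): broadcast params, map the batch over devices.
# out_axes=None: the pmean-ed loss is identical on every device, so pmap
# returns it without a stacked leading dimension -- cheap to log or index.
def step(p, batch):
    return jax.lax.pmean(jnp.mean((batch @ p) ** 2), axis_name='i')

pstep = jax.pmap(step, axis_name='i', in_axes=(None, 0), out_axes=None)

batch = jnp.ones((jax.local_device_count(), 8, 3))
loss = pstep(params, batch)  # scalar; no sharded leading dimension
```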

#### Option 2: Access local data directly
If you must work with a sharded array (or want potentially fewer changes to
code), you can access the local data shard directly without triggering JAX's
distributed consistency checks. Note that this is only recommended when bringing
data back to host (e.g., for logging, checkpointing). Instead of `x[0]`, use
`addressable_shards`:

```python
# Old slow way:
# result = x[0]

# New fast way:
# x.addressable_shards is a list of shards on the current process.
# We grab the first one, extract the data, and remove the leading dimension.
result = x.addressable_shards[0].data.squeeze(0)
```

In the example of `x` with shape `(8, 3, 4)`, `x.addressable_shards[0].data`
returns the local chunk of shape `(1, 3, 4)`. Calling `.squeeze(0)` results in
the desired `(3, 4)` shape without any cross-device communication. Both
solutions will eliminate the `_gather` operations seen in profiling.

### Host local array to global array round-trip conversion

In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be
host-local array when returning to user code.
This round-trip conversion cannot be avoided, so if the performance penalty is
too great, we recommend migrating your code to `jax.shard_map`.

## Migrating to `jax.shard_map`

In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by
dispatch path as in the `jax.shard_map` implementation of `jax.pmap` and can
often be overlapped with compute or be called infrequently (i.e., before a train
loop and for occasionally grabbing metrics).
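
A rough sketch of such a migration is below; the axis name, mesh construction,
and the trivial `psum` body are illustrative assumptions, not a drop-in recipe:

```python
import jax
import jax.numpy as jnp

mesh = jax.make_mesh((jax.device_count(),), ('i',))

# jax.pmap version:
#   f = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')

# Roughly equivalent jax.jit(jax.shard_map) version. in_specs/out_specs play
# the role of in_axes/out_axes: P('i') maps the leading dimension over the
# mesh axis, and P() returns a replicated (unmapped) result.
f = jax.jit(
    jax.shard_map(
        lambda x: jax.lax.psum(x, 'i'),
        mesh=mesh,
        in_specs=jax.P('i'),
        out_specs=jax.P(),
    )
)

x = jnp.arange(jax.device_count(), dtype=jnp.float32)
print(f(x))  # sum over all devices, replicated on every device
```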

(appendix-a)=
## Appendix A: More details about `int` indexing into sharded arrays

### What should `x[0]` return?

In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice
along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then `x[0]`
returns an array of shape `(3, 4)`.

In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns the
rank-reduced slice of the logical array `x`. However, performance depends on how
`x` is sharded or replicated across devices. Consider an array `x` with shape
`(8, 3, 4)` distributed across 8 devices (using `jax.P` as the short name for
`jax.sharding.PartitionSpec`); a sketch of the third case follows the list:

1. **Fully Replicated:** `jax.P(None, None, None)`
If `x` is fully replicated, every device holds a complete copy of the `(8,
3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec
`jax.P(None, None)`. Since every device already has `x`, this operation will
slice on each device independently and requires **no communication**.

2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)`
If `x` is sharded along the second dimension, `x[0]` results in shape `(3,
4)` with partition spec `jax.P('x', None)`. Since the first dimension (the
one being sliced) is unsharded, this operation also requires **no
communication**.

3. **Sharded on Leading Dimension:** `jax.P('x', None, None)`
If `x` is sharded along the first dimension, `x[0]` results in shape `(3,
4)` with partition spec `jax.P(None, None)`.
* **The Issue:** Because the first dimension is sharded, the data for
`x[0]` physically resides *only* on the first device. To satisfy the
output sharding `jax.P(None, None)` (which implies replication), JAX
must broadcast the data from the first device to all other devices. This
requires **communication**; JAX will gather the *entire* array of shape
`(8, 3, 4)` to each device and then take a slice.
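
A sketch of the third case (assuming 8 devices, as above; the mesh construction
via `jax.make_mesh` is an illustrative assumption):

```python
import jax
import jax.numpy as jnp

mesh = jax.make_mesh((8,), ('x',))

# Case 3: sharded along the leading dimension.
x = jax.device_put(
    jnp.ones((8, 3, 4)),
    jax.sharding.NamedSharding(mesh, jax.P('x', None, None)),
)

y = x[0]            # shape (3, 4)
print(y.sharding)   # fully replicated; producing this result required
                    # cross-device communication
```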

### The Common Performance Pitfall

A common pattern among `jax.pmap` users involves arrays that are **semantically
replicated** (the user intends for them to be identical everywhere) but are
**physically sharded** (stacked along the leading dimension).

This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly
(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or
checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap
operation.

#### Example: The "Unreplicate" Anti-Pattern

```python
from flax import jax_utils
import jax.numpy as jnp
import jax

# jax_utils.replicate calls jax.device_put_replicated.
# This stacks num_devices copies and SHARDS them over the stacked dimension.
# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None)
train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))})

# out_axes=0 by default, so the output remains sharded along dim 0.
train_step_pmapped = jax.pmap(lambda x: x)

# jax_utils.unreplicate performs a jax.tree_map(lambda x: x[0], tree).
# Users do this to grab metrics, log param statistics, checkpoint, etc.
train_state = jax_utils.unreplicate(train_step_pmapped(train_state))
```

#### The Consequence
Even though the user knows `train_state` contains identical data on every
device, JAX sees an array with shape `(8, 3, 4)` and spec `jax.P('x', None,
None)`, i.e., an array that is sharded along its leading dimension. JAX cannot
safely assume the data is identical on each device. Therefore, `x[0]` triggers a
gather of the entire array to all devices before slicing to ensure correctness.
This unnecessary communication causes performance degradation, visible as
`_gather` operations in a stack trace:

```
train
└─ jax_utils.py:48 unreplicate
└─ tree_util.py:354 tree_map
└─ jax_utils.py:50 <lambda> (performing x[0])
└─ array.py:335 __getitem__
└─ indexing.py:734 rewriting_take
└─ indexing.py:784 _gather
└─ slicing.py:324 gather
└─ PjitFunction(gather)
```

### Why was "Old Pmap" Fast?
Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in
`jax.Array`'s `__getitem__` allowing it to return an array with a
`SingleDeviceSharding` (data residing on only one device).

However, current JAX uses `NamedSharding`. We do not strictly replicate the
legacy behavior because it breaks the semantics of array indexing. If `x[0]`
were allowed to return a `SingleDeviceSharding` array in a general context
(e.g., in the middle of a train step, rather than when bringing data back to the
host for reporting), only one device would hold data while the others would hold
nothing, which is problematic for subsequent computation.

The slowdown users experience now is JAX enforcing correct semantics: if you ask
for `x[0]` from an array sharded along its leading dimension, you get a fully
replicated result available on all devices, which requires communication.

<!--* freshness: { reviewed: '2025-09-29' } *-->