diff --git a/docs/migrate_pmap.md b/docs/migrate_pmap.md
index d48aa8fb28cc..be0080577c7e 100644
--- a/docs/migrate_pmap.md
+++ b/docs/migrate_pmap.md
@@ -92,6 +92,49 @@ Mesh('y': 4, axis_types=(Auto,))
 
 ## Performance implications
 
+### `int` indexing into sharded arrays
+
+The new implementation of `jax.pmap` uses `NamedSharding` instead of the legacy
+`PmapSharding`. We've observed a common pattern with the old `jax.pmap` where
+users shard stacked copies of an array in order to replicate it (e.g., via
+`jax.device_put_replicated`). These "sharded-but-really-replicated" arrays
+incur unnecessary communication overhead when indexed with an `int` (e.g.,
+`x[0]`) because JAX does not know the arrays are actually replicated. For a
+more thorough discussion, please see [Appendix A](#appendix-a).
+
+#### Option 1: Prevent unintended sharding (recommended)
+Avoid creating the leading sharded dimension entirely; a short sketch follows
+the list below.
+
+- Use `jax.pmap`'s `out_axes=None` for outputs that should remain replicated.
+  The output will be fully replicated (e.g., `P(None, None)`), making access
+  cheap.
+- For inputs: when using `jax.device_put`, specify `jax.P()` (fully replicated)
+  in the partition spec rather than relying on utilities that stack and shard.
+  (Note: `jax.device_put_replicated` and `jax.device_put_sharded` are
+  deprecated because they confusingly produce sharded arrays rather than
+  replicated ones.)
+
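+For example, here is a minimal sketch of the two bullets above (the axis name
+`'i'`, the mesh axis name `'x'`, and the variable names are illustrative):
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.sharding import NamedSharding
+
+# Outputs: out_axes=None returns a result with no stacked device dimension,
+# so there is no leading axis to index away later.
+x = jnp.ones((jax.device_count(), 3))
+mean_loss = jax.pmap(lambda v: jax.lax.pmean(v.sum(), 'i'),
+                     axis_name='i', out_axes=None)(x)
+
+# Inputs: place data with an empty partition spec (fully replicated) instead
+# of stacking one copy per device and sharding the stacked dimension.
+mesh = jax.make_mesh((jax.device_count(),), ('x',))
+params = jax.device_put(jnp.zeros((3, 4)), NamedSharding(mesh, jax.P()))
+```
+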
+#### Option 2: Access local data directly
+If you must work with a sharded array (or want to minimize changes to existing
+code), you can access the local data shard directly without triggering JAX's
+distributed consistency checks. Note that this is only recommended when
+bringing data back to the host (e.g., for logging or checkpointing). Instead of
+`x[0]`, use `addressable_shards`:
+
+```python
+# Old slow way:
+# result = x[0]
+
+# New fast way:
+# x.addressable_shards is a list of shards on the current process.
+# We grab the first one, extract the data, and remove the leading dimension.
+result = x.addressable_shards[0].data.squeeze(0)
+```
+
+For example, if `x` has shape `(8, 3, 4)` and is sharded 8 ways along its
+leading dimension, `x.addressable_shards[0].data` returns the local chunk of
+shape `(1, 3, 4)`. Calling `.squeeze(0)` yields the desired `(3, 4)` shape
+without any cross-device communication. Both options eliminate the `_gather`
+operations seen in profiling.
+
 ### Host local array to global array round-trip conversion
 
 In multi-process JAX programs (i.e., `jax.process_count() > 1`), arrays might be
@@ -104,23 +147,6 @@ host-local array when returning to user code. This round-trip conversion cannot
 be avoided, so if the performance penalty is too great, we recommend migrating
 your code to `jax.shard_map`.
 
-### `int` array indexing
-
-Indexing into a sharded array with an int (e.g., `arr[0]`) may now execute a
-rank reduction computation. Depending on your use case, there may be
-workarounds:
-
-1. In a typical training loop, we might use a `jax.pmap`ed update function to
-   operate on / carry training state and grab resulting metrics from the first
-   `jax.pmap`'ed device for logging. In this case, it may be possible to
-   use `None` for the relevant `in_axes` and `out_axes` passed to `jax.pmap`.
-   This lets `jax.pmap` handle replication and will return an
-   appropriately-shaped result that looks like it's from a single device for,
-   say, logging metrics.
-2. More generally, you can get the first shard of data without a reshape via
-   `arr[0:1]` or `arr.addressable_shards[0].data`. Note that this will have a
-   leading `(1,)` dimension that your code will need to handle.
-
 ## Migrating to `jax.shard_map`
 
 In many cases, users can migrate from `jax.pmap` to `jax.jit(jax.shard_map)` by
@@ -132,4 +158,111 @@ dispatch path as in the `jax.shard_map` implementation of
 `jax.pmap` and can often be overlapped with compute or be called infrequently
 (i.e., before a train loop and for occasionally grabbing metrics).
 
+(appendix-a)=
+## Appendix A: More details about `int` indexing into sharded arrays
+
+### What should `x[0]` return?
+
+In **NumPy**, `x[0]` returns a rank-reduced array representing the first slice
+along the first dimension. For example, if `x = np.ones((8, 3, 4))`, then
+`x[0]` returns an array of shape `(3, 4)`.
+
+In **JAX** (`jax.numpy`), `x[0]` semantically works the same way: it returns
+the rank-reduced slice of the logical array `x`. However, performance depends
+on how `x` is sharded or replicated across devices. Consider an array `x` with
+shape `(8, 3, 4)` distributed across 8 devices, using `jax.P` as the short name
+for `jax.sharding.PartitionSpec` (a short sketch of the problematic case
+follows the list):
+
+1. **Fully Replicated:** `jax.P(None, None, None)`
+   If `x` is fully replicated, every device holds a complete copy of the
+   `(8, 3, 4)` array. `x[0]` will have the shape `(3, 4)` and a partition spec
+   `jax.P(None, None)`. Since every device already has `x`, this operation will
+   slice on each device independently and requires **no communication**.
+
+2. **Sharded on Non-Leading Dimension:** `jax.P(None, 'x', None)`
+   If `x` is sharded along the second dimension, `x[0]` results in shape
+   `(3, 4)` with partition spec `jax.P('x', None)`. Since the first dimension
+   (the one being sliced) is unsharded, this operation also requires **no
+   communication**.
+
+3. **Sharded on Leading Dimension:** `jax.P('x', None, None)`
+   If `x` is sharded along the first dimension, `x[0]` results in shape
+   `(3, 4)` with partition spec `jax.P(None, None)`.
+   * **The Issue:** Because the first dimension is sharded, the data for
+     `x[0]` physically resides *only* on the first device. To satisfy the
+     output sharding `jax.P(None, None)` (which implies replication), JAX
+     must broadcast the data from the first device to all other devices. This
+     requires **communication**; JAX will gather the *entire* array of shape
+     `(8, 3, 4)` to each device and then take a slice.
+
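+To make the problematic case concrete, here is a minimal sketch (the mesh axis
+name `'x'` and the variable names are illustrative, and the leading dimension
+is sized to the available devices):
+
+```python
+import jax
+import jax.numpy as jnp
+from jax.sharding import NamedSharding
+
+n = jax.device_count()
+mesh = jax.make_mesh((n,), ('x',))
+
+# Case 3: shard the leading dimension across devices.
+x = jax.device_put(jnp.ones((n, 3, 4)),
+                   NamedSharding(mesh, jax.P('x', None, None)))
+
+y = x[0]           # logical result: shape (3, 4)
+print(y.shape)     # (3, 4)
+print(y.sharding)  # fully replicated, i.e., equivalent to jax.P(None, None)
+# Producing y required gathering x across devices before slicing.
+```
+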
+### The Common Performance Pitfall
+
+A common pattern among `jax.pmap` users involves arrays that are **semantically
+replicated** (the user intends for them to be identical everywhere) but are
+**physically sharded** (stacked along the leading dimension).
+
+This happens implicitly (e.g., via `jax.pmap(..., out_axes=0)`) or explicitly
+(e.g., via `jax.device_put_replicated`). Users often try to retrieve metrics or
+checkpoints by calling `unreplicate` or `x[0]`, assuming it is a cheap
+operation.
+
+#### Example: The "Unreplicate" Anti-Pattern
+
+```python
+from flax import jax_utils
+import jax.numpy as jnp
+import jax
+
+# jax_utils.replicate calls jax.device_put_replicated.
+# This stacks num_devices copies and SHARDS them over the stacked dimension.
+# Logical Shape: (8, 3, 4) | Sharding: P('x', None, None)
+train_state = jax_utils.replicate({'params': jnp.zeros((3, 4))})
+
+# out_axes=0 by default, so the output remains sharded along dim 0.
+train_step_pmapped = jax.pmap(lambda x: x)
+
+# jax_utils.unreplicate applies `lambda x: x[0]` to every leaf via tree_map.
+# Users do this to grab metrics, log param statistics, checkpoint, etc.
+train_state = jax_utils.unreplicate(train_step_pmapped(train_state))
+```
+
+#### The Consequence
+Even though the user knows `train_state` contains identical data on every
+device, JAX sees an array with shape `(8, 3, 4)` and spec `jax.P('x', None,
+None)`, i.e., an array that is sharded along its leading dimension. JAX cannot
+safely assume the data is identical on each device. Therefore, `x[0]` triggers
+a gather of the entire array to all devices before slicing to ensure
+correctness. This unnecessary communication causes performance degradation
+(visible as `_gather` operations in a stack trace):
+
+```
+train
+  └─ jax_utils.py:48 unreplicate
+      └─ tree_util.py:354 tree_map
+          └─ jax_utils.py:50 (performing x[0])
+              └─ array.py:335 __getitem__
+                  └─ indexing.py:734 rewriting_take
+                      │
+                      ▼
+                      └─ indexing.py:784 _gather
+                          └─ slicing.py:324 gather
+                              └─ PjitFunction(gather)
+```
+
+### Why was "Old Pmap" Fast?
+Historically, `pmap` used `PmapSharding`, which had a fast-path optimization in
+`jax.Array`'s `__getitem__` that allowed it to return an array with a
+`SingleDeviceSharding` (data residing on only one device).
+
+However, the current implementation uses `NamedSharding`. We do not strictly
+replicate the legacy behavior because it breaks the semantics of array
+indexing. If we allowed `x[0]` to return a `SingleDeviceSharding` array in a
+general context (e.g., in the middle of a train step rather than when bringing
+data back to the host for reporting), only one device would hold the data while
+the others would have nothing, which is problematic for subsequent computation.
+
+The slowdown users experience now is JAX enforcing correct semantics: if you
+ask for `x[0]` from an array sharded along its leading dimension, you get a
+fully replicated result available on all devices, which requires communication.
+
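+If you only need the values back on the host, Option 2 in the performance
+section above avoids this communication entirely. Here is a minimal sketch
+applied to the "unreplicate" pattern (the helper name `cheap_unreplicate` is
+illustrative):
+
+```python
+import jax
+
+def cheap_unreplicate(tree):
+  # Take the process-local shard of each leaf and drop the stacked leading
+  # dimension, instead of x[0] (which would trigger a cross-device gather).
+  return jax.tree_util.tree_map(
+      lambda x: x.addressable_shards[0].data.squeeze(0), tree)
+
+# For example: metrics = cheap_unreplicate(train_step_pmapped(train_state))
+```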