21 changes: 12 additions & 9 deletions docs_nnx/guides/flax_gspmd.ipynb
@@ -69,7 +69,11 @@
"outputs": [],
"source": [
"# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n",
"auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))"
"auto_mesh = jax.make_mesh(\n",
" (2, 4),\n",
" ('data', 'model'),\n",
" axis_types=(AxisType.Auto, AxisType.Auto),\n",
")"
]
},
{
@@ -203,7 +207,7 @@
"source": [
"### Initialize with style\n",
"\n",
"When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
"When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
"\n",
"Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out."
]
@@ -216,10 +220,9 @@
"source": [
"@jax.jit\n",
"def init_sharded_linear(key):\n",
" init_fn = nnx.nn.linear.default_kernel_init\n",
" # Shard your parameter along `model` dimension, as in model/tensor parallelism\n",
" return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n",
" kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n",
" kernel_metadata={'sharding_names': (None, 'model')})\n",
"\n",
"with jax.set_mesh(auto_mesh):\n",
" key= rngs()\n",
@@ -328,12 +331,12 @@
" init_fn = nnx.initializers.lecun_normal()\n",
" self.dot1 = nnx.Linear(\n",
" depth, depth,\n",
" kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n",
" use_bias=False, # or use `bias_init` to give it annotation too\n",
" kernel_metadata={'sharding_names': (None, 'model')},\n",
" use_bias=False, # or use `bias_metadata` to give it annotation too\n",
" rngs=rngs)\n",
" self.w2 = nnx.Param(\n",
" init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",
" sharding=('model', None),\n",
" sharding=('model', None), # same as sharding_names=('model', None)\n",
" )\n",
"\n",
" def __call__(self, x: jax.Array):\n",
@@ -512,8 +515,8 @@
" init_fn = nnx.initializers.lecun_normal()\n",
" self.dot1 = nnx.Linear(\n",
" depth, depth,\n",
" kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n",
" use_bias=False, # or use `bias_init` to give it annotation too\n",
" kernel_metadata={'sharding_names': ('embed', 'hidden')},\n",
" use_bias=False, # or use `bias_metadata` to give it annotation too\n",
" rngs=rngs)\n",
" self.w2 = nnx.Param(\n",
" init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",
21 changes: 12 additions & 9 deletions docs_nnx/guides/flax_gspmd.md
@@ -44,7 +44,11 @@ In this guide we use a standard FSDP layout and shard our devices on two axes -

```{code-cell} ipython3
# Create an auto-mode mesh of two dimensions and annotate each axis with a name.
-auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
+auto_mesh = jax.make_mesh(
+    (2, 4),
+    ('data', 'model'),
+    axis_types=(AxisType.Auto, AxisType.Auto),
+)
```

> Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html), which greatly simplifies creating sharded models. If your project already used the Flax GSPMD API on `flax<0.12`, you might have turned the feature off to keep your code working. You can toggle this feature with the `nnx.use_eager_sharding` function.
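For reference, a minimal standalone sketch of the mesh setup above, spelling out the imports the snippet assumes; the `jax.sharding.AxisType` import path and the 2x4 device count are assumptions for illustration, not part of this diff.

```python
import jax
from jax.sharding import AxisType  # assumed import path for AxisType

# Auto-mode mesh: 2 devices along 'data', 4 along 'model' (needs 8 devices,
# e.g. on CPU via XLA_FLAGS=--xla_force_host_platform_device_count=8).
auto_mesh = jax.make_mesh(
    (2, 4),
    ('data', 'model'),
    axis_types=(AxisType.Auto, AxisType.Auto),
)
print(auto_mesh.shape)  # axis sizes: data=2, model=4
```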
@@ -89,17 +93,16 @@ with jax.set_mesh(auto_mesh):

### Initialize with style

-When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.
+When using existing modules, you can use the `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.

Also, you should use `jax.jit` for the whole initialization for maximum performance. Without `jax.jit`, the variable is first created on a single device and only sharded afterwards, which is wasteful; `jax.jit` optimizes this away automatically.

```{code-cell} ipython3
@jax.jit
def init_sharded_linear(key):
-    init_fn = nnx.nn.linear.default_kernel_init
# Shard your parameter along `model` dimension, as in model/tensor parallelism
return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),
-    kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))
+    kernel_metadata={'sharding_names': (None, 'model')})

with jax.set_mesh(auto_mesh):
key= rngs()
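The hunk ends right after the key is drawn; as a hedged follow-up, here is one way to check that the kernel really comes out sharded, assuming the `init_sharded_linear`, `auto_mesh`, and `rngs` defined above.

```python
# Sketch only: inspect the resulting sharding of the (4, 8) kernel.
with jax.set_mesh(auto_mesh):
    key = rngs()
    linear = init_sharded_linear(key)
    # The second kernel axis should be split across the 4 devices of 'model'.
    print(linear.kernel.value.sharding)
    jax.debug.visualize_array_sharding(linear.kernel.value)
```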
@@ -144,12 +147,12 @@ class DotReluDot(nnx.Module):
init_fn = nnx.initializers.lecun_normal()
self.dot1 = nnx.Linear(
depth, depth,
-      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
-      use_bias=False, # or use `bias_init` to give it annotation too
+      kernel_metadata={'sharding_names': (None, 'model')},
+      use_bias=False, # or use `bias_metadata` to give it annotation too
rngs=rngs)
self.w2 = nnx.Param(
init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
-      sharding=('model', None),
+      sharding=('model', None), # same as sharding_names=('model', None)
)

def __call__(self, x: jax.Array):
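As a hedged aside on the hunk above: the annotations can be read back as `PartitionSpec`s. This sketch assumes `DotReluDot` takes `depth` and `rngs` as in the full guide, that `auto_mesh` is in scope, and that `nnx.get_partition_spec` resolves the `sharding`/`sharding_names` metadata as in earlier Flax releases.

```python
import jax
from flax import nnx

with jax.set_mesh(auto_mesh):
    model = DotReluDot(8, rngs=nnx.Rngs(0))

# Collect the variables and turn their sharding annotations into PartitionSpecs.
specs = nnx.get_partition_spec(nnx.state(model))
print(specs)  # expected: (None, 'model') for dot1.kernel, ('model', None) for w2
```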
@@ -258,8 +261,8 @@ class LogicalDotReluDot(nnx.Module):
init_fn = nnx.initializers.lecun_normal()
self.dot1 = nnx.Linear(
depth, depth,
-      kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),
-      use_bias=False, # or use `bias_init` to give it annotation too
+      kernel_metadata={'sharding_names': ('embed', 'hidden')},
+      use_bias=False, # or use `bias_metadata` to give it annotation too
rngs=rngs)
self.w2 = nnx.Param(
init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
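The logical names `'embed'` and `'hidden'` used here are not mesh axes; elsewhere in the guide they are translated to the physical `'data'`/`'model'` axes through a set of sharding rules. A hand-rolled illustration of that translation follows; the `rules` dict and `to_physical` helper are hypothetical, not the Flax API.

```python
from jax.sharding import PartitionSpec

# Hypothetical logical-to-physical mapping, for illustration only.
rules = {'batch': 'data', 'hidden': 'model', 'embed': None}

def to_physical(logical_names):
    # Replace each logical name with the mesh axis it maps to (or None).
    return PartitionSpec(*(rules.get(name) for name in logical_names))

print(to_physical(('embed', 'hidden')))  # PartitionSpec(None, 'model')
```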