
Commit b1b61b6

Updated flax_gspmd to use kernel_metadata
1 parent 7710c30 commit b1b61b6

File tree

docs_nnx/guides/flax_gspmd.ipynb
docs_nnx/guides/flax_gspmd.md

2 files changed: 24 additions, 18 deletions

docs_nnx/guides/flax_gspmd.ipynb

Lines changed: 12 additions & 9 deletions
@@ -69,7 +69,11 @@
    "outputs": [],
    "source": [
     "# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n",
-    "auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))"
+    "auto_mesh = jax.make_mesh(\n",
+    "    (2, 4),\n",
+    "    ('data', 'model'),\n",
+    "    axis_types=(AxisType.Auto, AxisType.Auto),\n",
+    ")"
    ]
   },
   {
@@ -203,7 +207,7 @@
    "source": [
     "### Initialize with style\n",
     "\n",
-    "When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
+    "When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n",
     "\n",
     "Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out."
    ]
@@ -216,10 +220,9 @@
    "source": [
     "@jax.jit\n",
     "def init_sharded_linear(key):\n",
-    "  init_fn = nnx.nn.linear.default_kernel_init\n",
     "  # Shard your parameter along `model` dimension, as in model/tensor parallelism\n",
     "  return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n",
-    "                    kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n",
+    "                    kernel_metadata={'sharding_names': (None, 'model')})\n",
     "\n",
     "with jax.set_mesh(auto_mesh):\n",
     "  key= rngs()\n",
@@ -328,12 +331,12 @@
     "    init_fn = nnx.initializers.lecun_normal()\n",
     "    self.dot1 = nnx.Linear(\n",
     "      depth, depth,\n",
-    "      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n",
-    "      use_bias=False, # or use `bias_init` to give it annotation too\n",
+    "      kernel_metadata={'sharding_names': (None, 'model')},\n",
+    "      use_bias=False, # or use `bias_metadata` to give it annotation too\n",
     "      rngs=rngs)\n",
     "    self.w2 = nnx.Param(\n",
     "      init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",
-    "      sharding=('model', None),\n",
+    "      sharding=('model', None), # same as sharding_names=('model', None)\n",
     "    )\n",
     "\n",
     "  def __call__(self, x: jax.Array):\n",
@@ -512,8 +515,8 @@
     "    init_fn = nnx.initializers.lecun_normal()\n",
     "    self.dot1 = nnx.Linear(\n",
     "      depth, depth,\n",
-    "      kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n",
-    "      use_bias=False, # or use `bias_init` to give it annotation too\n",
+    "      kernel_metadata={'sharding_names': ('embed', 'hidden')},\n",
+    "      use_bias=False, # or use `bias_metadata` to give it annotation too\n",
     "      rngs=rngs)\n",
     "    self.w2 = nnx.Param(\n",
     "      init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n",

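Note (editor's sketch, not part of the commit): the updated notebook cell now passes `axis_types` explicitly when building the mesh. A minimal standalone version is sketched below, assuming a recent JAX release where `jax.make_mesh` accepts `axis_types` and `AxisType` comes from `jax.sharding` (the diff uses the bare name, so the import presumably appears elsewhere in the notebook), and assuming 2 x 4 = 8 visible devices (e.g. `XLA_FLAGS=--xla_force_host_platform_device_count=8` on CPU).

```python
import jax
from jax.sharding import AxisType  # assumed import location for the `AxisType` used in the diff

# Create an auto-mode mesh of two dimensions and annotate each axis with a name,
# mirroring the updated cell above. Requires 8 devices (2 data x 4 model).
auto_mesh = jax.make_mesh(
    (2, 4),
    ('data', 'model'),
    axis_types=(AxisType.Auto, AxisType.Auto),
)
print(auto_mesh.axis_names)  # ('data', 'model')
```
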
docs_nnx/guides/flax_gspmd.md

Lines changed: 12 additions & 9 deletions
@@ -44,7 +44,11 @@ In this guide we use a standard FSDP layout and shard our devices on two axes -
 
 ```{code-cell} ipython3
 # Create an auto-mode mesh of two dimensions and annotate each axis with a name.
-auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
+auto_mesh = jax.make_mesh(
+    (2, 4),
+    ('data', 'model'),
+    axis_types=(AxisType.Auto, AxisType.Auto),
+)
 ```
 
 > Compatibility Note: This guide covers the [eager sharding feature](https://flax.readthedocs.io/en/latest/flip/4844-var-eager-sharding.html) that greatly simplifies creating sharded model. If your project already used Flax GSPMD API on version `flax<0.12`, you might have turned the feature off to keep your code working. Users can toggle this feature using the `nnx.use_eager_sharding` function.
@@ -89,17 +93,16 @@ with jax.set_mesh(auto_mesh):
 
 ### Initialize with style
 
-When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.
+When using existing modules, you can use `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.
 
 Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out.
 
 ```{code-cell} ipython3
 @jax.jit
 def init_sharded_linear(key):
-  init_fn = nnx.nn.linear.default_kernel_init
   # Shard your parameter along `model` dimension, as in model/tensor parallelism
   return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),
-                    kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))
+                    kernel_metadata={'sharding_names': (None, 'model')})
 
 with jax.set_mesh(auto_mesh):
   key= rngs()
@@ -144,12 +147,12 @@ class DotReluDot(nnx.Module):
     init_fn = nnx.initializers.lecun_normal()
     self.dot1 = nnx.Linear(
       depth, depth,
-      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
-      use_bias=False, # or use `bias_init` to give it annotation too
+      kernel_metadata={'sharding_names': (None, 'model')},
+      use_bias=False, # or use `bias_metadata` to give it annotation too
       rngs=rngs)
     self.w2 = nnx.Param(
       init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
-      sharding=('model', None),
+      sharding=('model', None), # same as sharding_names=('model', None)
     )
 
   def __call__(self, x: jax.Array):
@@ -258,8 +261,8 @@ class LogicalDotReluDot(nnx.Module):
     init_fn = nnx.initializers.lecun_normal()
     self.dot1 = nnx.Linear(
       depth, depth,
-      kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),
-      use_bias=False, # or use `bias_init` to give it annotation too
+      kernel_metadata={'sharding_names': ('embed', 'hidden')},
+      use_bias=False, # or use `bias_metadata` to give it annotation too
       rngs=rngs)
     self.w2 = nnx.Param(
       init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation
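
Note (editor's sketch, not part of the commit): the hunks above swap the `nnx.with_partitioning`-wrapped initializers for the new `kernel_metadata` / `bias_metadata` arguments. The sketch below contrasts the two styles, assuming `flax>=0.12` (eager sharding and the metadata arguments available) and the `auto_mesh` defined earlier in the guide.

```python
import jax
from flax import nnx

rngs = nnx.Rngs(0)
init_fn = nnx.initializers.lecun_normal()

with jax.set_mesh(auto_mesh):  # `auto_mesh` as created in the guide
  # Old style (what the commit removes): wrap the initializer with sharding names.
  old_linear = nnx.Linear(
      4, 8, use_bias=False, rngs=rngs,
      kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))

  # New style (what the commit documents): pass sharding metadata directly.
  new_linear = nnx.Linear(
      4, 8, use_bias=False, rngs=rngs,
      kernel_metadata={'sharding_names': (None, 'model')})

  # A raw nnx.Param takes the same annotation; per the guide's new comment,
  # `sharding=(...)` is the same as `sharding_names=(...)`.
  w2 = nnx.Param(
      init_fn(rngs.params(), (8, 8)),
      sharding=('model', None))
```

Both forms should yield a kernel sharded along the `model` mesh axis; the metadata arguments simply avoid wrapping every initializer.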
