|
69 | 69 | "outputs": [], |
70 | 70 | "source": [ |
71 | 71 | "# Create an auto-mode mesh of two dimensions and annotate each axis with a name.\n", |
72 | | - "auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))" |
| 72 | + "auto_mesh = jax.make_mesh(\n", |
| 73 | + " (2, 4),\n", |
| 74 | + " ('data', 'model'),\n", |
| 75 | + " axis_types=(AxisType.Auto, AxisType.Auto),\n", |
| 76 | + ")" |
73 | 77 | ] |
74 | 78 | }, |
75 | 79 | { |
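For readers reproducing the new cell outside the notebook, here is a minimal standalone sketch. It assumes 8 available devices and that `AxisType` is imported from `jax.sharding` (the notebook's earlier import cell presumably provides it):

```python
import jax
from jax.sharding import AxisType

# 2x4 device mesh; both axes are Auto, so the compiler picks shardings
# unless an annotation or constraint says otherwise.
auto_mesh = jax.make_mesh(
    (2, 4),
    ('data', 'model'),
    axis_types=(AxisType.Auto, AxisType.Auto),
)

print(auto_mesh.axis_names)     # ('data', 'model')
print(auto_mesh.devices.shape)  # (2, 4)
```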
|
203 | 207 | "source": [ |
204 | 208 | "### Initialize with style\n", |
205 | 209 | "\n", |
206 | | - "When using existing modules, you can apply [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) on initializers to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight.\n", |
| 210 | + "When using existing modules, you can use the `kernel_metadata` and `bias_metadata` arguments to achieve the same effect. Here we create a sharded `nnx.Linear` module with only the kernel weight (no bias).\n", |
207 | 211 | "\n", |
208 | 212 | "Also, you should use `jax.jit` for the whole initialization for maximum performance. This is because without `jax.jit`, a single-device variable must be created first before we apply sharding constraints and then make it sharded, which is wasteful. `jax.jit` will automatically optimize this out." |
209 | 213 | ] |
|
216 | 220 | "source": [ |
217 | 221 | "@jax.jit\n", |
218 | 222 | "def init_sharded_linear(key):\n", |
219 | | - " init_fn = nnx.nn.linear.default_kernel_init\n", |
220 | 223 | " # Shard your parameter along `model` dimension, as in model/tensor parallelism\n", |
221 | 224 | " return nnx.Linear(4, 8, use_bias=False, rngs=nnx.Rngs(key),\n", |
222 | | - " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')))\n", |
| 225 | + " kernel_metadata={'sharding_names': (None, 'model')})\n", |
223 | 226 | "\n", |
224 | 227 | "with jax.set_mesh(auto_mesh):\n", |
225 | 228 | "  key = rngs()\n", |
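As a follow-up to the hunk above (not part of the commit), here is a hedged sketch of how the resulting placement can be checked, assuming the `auto_mesh`, `rngs`, and `init_sharded_linear` names defined earlier in the notebook:

```python
with jax.set_mesh(auto_mesh):
  linear = init_sharded_linear(rngs())

# With `sharding_names=(None, 'model')`, the (4, 8) kernel should be
# replicated over `data` and split across the `model` axis on its second dimension.
print(linear.kernel.value.shape)  # (4, 8)
print(linear.kernel.value.sharding)
jax.debug.visualize_array_sharding(linear.kernel.value)
```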
|
328 | 331 | " init_fn = nnx.initializers.lecun_normal()\n", |
329 | 332 | " self.dot1 = nnx.Linear(\n", |
330 | 333 | " depth, depth,\n", |
331 | | - " kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),\n", |
332 | | - " use_bias=False, # or use `bias_init` to give it annotation too\n", |
| 334 | + " kernel_metadata={'sharding_names': (None, 'model')},\n", |
| 335 | + " use_bias=False, # or use `bias_metadata` to give it annotation too\n", |
333 | 336 | " rngs=rngs)\n", |
334 | 337 | " self.w2 = nnx.Param(\n", |
335 | 338 | " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", |
336 | | - " sharding=('model', None),\n", |
| 339 | + " sharding=('model', None), # same as sharding_names=('model', None)\n", |
337 | 340 | " )\n", |
338 | 341 | "\n", |
339 | 342 | " def __call__(self, x: jax.Array):\n", |
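The inline comment on the `nnx.Param` line above states that `sharding=` is shorthand for `sharding_names=`. A small illustrative sketch of that equivalence, assuming the `init_fn`, `rngs`, and `depth` names from the class (not part of the commit):

```python
# Per the comment in the cell above, both spellings attach the same
# sharding annotation to the parameter's metadata.
w_short = nnx.Param(init_fn(rngs.params(), (depth, depth)), sharding=('model', None))
w_long = nnx.Param(init_fn(rngs.params(), (depth, depth)), sharding_names=('model', None))
```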
|
512 | 515 | " init_fn = nnx.initializers.lecun_normal()\n", |
513 | 516 | " self.dot1 = nnx.Linear(\n", |
514 | 517 | " depth, depth,\n", |
515 | | - " kernel_init=nnx.with_partitioning(init_fn, ('embed', 'hidden')),\n", |
516 | | - " use_bias=False, # or use `bias_init` to give it annotation too\n", |
| 518 | + " kernel_metadata={'sharding_names': ('embed', 'hidden')},\n", |
| 519 | + " use_bias=False, # or use `bias_metadata` to give it annotation too\n", |
517 | 520 | " rngs=rngs)\n", |
518 | 521 | " self.w2 = nnx.Param(\n", |
519 | 522 | " init_fn(rngs.params(), (depth, depth)), # RNG key and shape for W2 creation\n", |
|