@@ -33,7 +33,10 @@ class Embed(nnx.Module):
3333 num_embeddings: number of embeddings.
3434 features: number of feature dimensions for each embedding.
3535 dtype: the dtype of the embedding vectors (default: float32).
36+ param_dtype: the dtype of the embedding parameters.
37+ promote_dtype: the dtype promotion function.
3638 embedding_init: embedding initializer.
39+ rngs: rng keys.
3740 """
3841
3942 def __init__ (
@@ -43,6 +46,7 @@ def __init__(
4346 dtype : jnp .dtype | None = None ,
4447 param_dtype : jnp .dtype = jnp .bfloat16 ,
4548 promote_dtype : PromoteDtypeFn = dtypes .promote_dtype ,
49+ embedding_init : nnx .Initializer = default_embed_init ,
4650 rngs : nnx .Rngs = None ,
4751 ):
4852 """
@@ -63,9 +67,7 @@ def __init__(
6367 rngs: Random number generator state for parameter initialization.
6468 """
6569 self .embedding = nnx .Param (
66- nnx .with_partitioning (default_embed_init , (None , None ))(
67- jax .random .PRNGKey (0 ), (num_embeddings , features ), param_dtype
68- )
70+ embedding_init (jax .random .PRNGKey (0 ), (num_embeddings , features ), param_dtype )
6971 )
7072
7173 self .num_embeddings = num_embeddings
@@ -124,6 +126,7 @@ def __init__(
124126 dtype : jnp .dtype | None = None ,
125127 param_dtype : jnp .dtype = jnp .bfloat16 ,
126128 promote_dtype : PromoteDtypeFn = dtypes .promote_dtype ,
129+ embedding_init : nnx .Initializer = default_embed_init ,
127130 rngs : nnx .Rngs = None ,
128131 use_bias : bool = False ,
129132 ):
@@ -148,6 +151,7 @@ def __init__(
148151 dtype = dtype ,
149152 param_dtype = param_dtype ,
150153 promote_dtype = promote_dtype ,
154+ embedding_init = embedding_init ,
151155 rngs = rngs ,
152156 )
153157 if use_bias :
0 commit comments