Skip to content

Commit fa0742e

Browse files
committed
fix head dim pad
1 parent f5234f3 commit fa0742e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

python/sgl_jax/srt/layers/embeddings.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,10 @@ class Embed(nnx.Module):
3333
num_embeddings: number of embeddings.
3434
features: number of feature dimensions for each embedding.
3535
dtype: the dtype of the embedding vectors (default: None).
36+
param_dtype: the dtype of the embedding parameters (default: bfloat16).
37+
promote_dtype: the dtype promotion function.
3638
embedding_init: embedding initializer.
39+
rngs: rng keys.
3740
"""
3841

3942
def __init__(
@@ -43,6 +46,7 @@ def __init__(
4346
dtype: jnp.dtype | None = None,
4447
param_dtype: jnp.dtype = jnp.bfloat16,
4548
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
49+
embedding_init: nnx.Initializer = default_embed_init,
4650
rngs: nnx.Rngs = None,
4751
):
4852
"""
@@ -63,9 +67,7 @@ def __init__(
6367
rngs: Random number generator state for parameter initialization.
6468
"""
6569
self.embedding = nnx.Param(
66-
nnx.with_partitioning(default_embed_init, (None, None))(
67-
jax.random.PRNGKey(0), (num_embeddings, features), param_dtype
68-
)
70+
embedding_init(jax.random.PRNGKey(0), (num_embeddings, features), param_dtype)
6971
)
7072

7173
self.num_embeddings = num_embeddings
@@ -124,6 +126,7 @@ def __init__(
124126
dtype: jnp.dtype | None = None,
125127
param_dtype: jnp.dtype = jnp.bfloat16,
126128
promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
129+
embedding_init: nnx.Initializer = default_embed_init,
127130
rngs: nnx.Rngs = None,
128131
use_bias: bool = False,
129132
):
@@ -148,6 +151,7 @@ def __init__(
148151
dtype=dtype,
149152
param_dtype=param_dtype,
150153
promote_dtype=promote_dtype,
154+
embedding_init=embedding_init,
151155
rngs=rngs,
152156
)
153157
if use_bias:

0 commit comments

Comments
 (0)