Commit f5234f3

fix head dim pad

1 parent d9c521b commit f5234f3

8 files changed: +22 −25 lines changed


python/sgl_jax/srt/layers/embeddings.py

Lines changed: 3 additions & 7 deletions
@@ -33,10 +33,7 @@ class Embed(nnx.Module):
        num_embeddings: number of embeddings.
        features: number of feature dimensions for each embedding.
        dtype: the dtype of the embedding vectors (default: float32).
-        param_dtype: the dtype of the embedding parameters.
-        promote_dtype: the dtype promotion function.
        embedding_init: embedding initializer.
-        rngs: rng keys.
    """

    def __init__(
@@ -46,7 +43,6 @@ def __init__(
        dtype: jnp.dtype | None = None,
        param_dtype: jnp.dtype = jnp.bfloat16,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-        embedding_init: nnx.Initializer = default_embed_init,
        rngs: nnx.Rngs = None,
    ):
        """
@@ -67,7 +63,9 @@ def __init__(
            rngs: Random number generator state for parameter initialization.
        """
        self.embedding = nnx.Param(
-            embedding_init(jax.random.PRNGKey(0), (num_embeddings, features), param_dtype)
+            nnx.with_partitioning(default_embed_init, (None, None))(
+                jax.random.PRNGKey(0), (num_embeddings, features), param_dtype
+            )
        )

        self.num_embeddings = num_embeddings
@@ -126,7 +124,6 @@ def __init__(
        dtype: jnp.dtype | None = None,
        param_dtype: jnp.dtype = jnp.bfloat16,
        promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
-        embedding_init: nnx.Initializer = default_embed_init,
        rngs: nnx.Rngs = None,
        use_bias: bool = False,
    ):
@@ -151,7 +148,6 @@ def __init__(
            dtype=dtype,
            param_dtype=param_dtype,
            promote_dtype=promote_dtype,
-            embedding_init=embedding_init,
            rngs=rngs,
        )
        if use_bias:
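For readers less familiar with Flax NNX, the pattern this hunk switches to can be sketched in isolation: nnx.with_partitioning wraps an initializer so the resulting nnx.Param carries sharding metadata. The shapes and the variance-scaling initializer below are stand-ins (the repo's default_embed_init is assumed to behave like one); this is a sketch, not the repo code.

# Sketch only: mirrors the new nnx.Param construction above with stand-in
# shapes; default_embed_init is assumed to be a variance-scaling initializer.
import jax
import jax.numpy as jnp
from flax import nnx

vocab_size, hidden_size = 32000, 4096  # illustrative shapes
embed_init = nnx.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)

embedding = nnx.Param(
    # (None, None): no named mesh axis on either dimension, i.e. the table
    # is replicated; a tensor-parallel layout would name a mesh axis here.
    nnx.with_partitioning(embed_init, (None, None))(
        jax.random.PRNGKey(0), (vocab_size, hidden_size), jnp.bfloat16
    )
)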

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def load_model(self):
        self.model_config.configure_for_tensor_parallel(self.tp_size)
        self.model_config.log_kv_heads_info(self.tp_size)
        self.model_config.hf_config.ep_size = self.ep_size
-        self.model_config.hf_config.head_dim = self.model_config.get_padded_head_dim()
+        self.model_config.hf_config.head_dim_padded = self.model_config.get_padded_head_dim()

        self.model = self.model_loader.load_model(
            model_config=self.model_config,
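The old line stored the padded value directly in hf_config.head_dim, so the original head dimension was no longer visible downstream; the new line writes it to a separate head_dim_padded attribute, which the model files below now read instead. As a rough illustration of what padding a head dim means (the real rule lives in ModelConfig.get_padded_head_dim(); the round-up-to-128 below is only an assumption):

# Hypothetical sketch of head-dim padding; the repo's actual logic is in
# ModelConfig.get_padded_head_dim() and may use a different rule.
def padded_head_dim(head_dim: int, multiple: int = 128) -> int:
    # Round up to the nearest multiple (128 is a typical TPU lane width,
    # chosen here purely as an illustrative assumption).
    return ((head_dim + multiple - 1) // multiple) * multiple

assert padded_head_dim(80) == 128    # an 80-dim head would be padded up
assert padded_head_dim(128) == 128   # already-aligned dims are unchanged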

python/sgl_jax/srt/models/bailing_moe.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def __init__(
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        self.head_dim = getattr(config, "head_dim", None)
+        self.head_dim = getattr(config, "head_dim_padded", None)
        use_qk_norm = getattr(config, "use_qk_norm", False)
        if hasattr(config, "partial_rotary_factor"):
            rotary_dim = int(self.head_dim * config.partial_rotary_factor)

python/sgl_jax/srt/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def __init__(
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)

-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,

python/sgl_jax/srt/models/qwen.py

Lines changed: 13 additions & 12 deletions
@@ -77,15 +77,15 @@ def __init__(
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
+        head_dim: int | None = None,
        layer_id: int = 0,
        dtype: jnp.dtype = jnp.float16,
        rngs: nnx.Rngs = None,
    ):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
-        head_size = hidden_size // num_heads
-        self.head_size = head_size
-        self.scaling = head_size**-0.5
+        self.head_dim = head_dim or hidden_size // num_heads
+        self.scaling = head_dim**-0.5

        self.q_proj = LinearBase(
            input_size=hidden_size,
@@ -112,7 +112,7 @@ def __init__(
            params_dtype=dtype,
        )
        self.c_proj = LinearBase(
-            input_size=num_heads * head_size,
+            input_size=num_heads * head_dim,
            output_size=hidden_size,
            use_bias=False,
            kernel_axes=("tensor", None),
@@ -122,17 +122,17 @@ def __init__(

        # Use torch version of RotaryEmbedding directly
        self.rotary_emb = RotaryEmbedding(
-            head_size=head_size,
-            rotary_dim=head_size,
+            head_size=head_dim,
+            rotary_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            base=rope_theta,
            is_neox_style=True,
            dtype=dtype,
        )
-        self.scaling = head_size**-0.5
+        self.scaling = head_dim**-0.5
        self.attn = RadixAttention(
            num_heads=num_heads,
-            head_dim=head_size,
+            head_dim=head_dim,
            scaling=self.scaling,
            num_kv_heads=num_heads,
            layer_id=layer_id,
@@ -150,9 +150,9 @@ def __call__(
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)

-        q = q.reshape(-1, self.num_heads, self.head_size)
-        k = k.reshape(-1, self.num_heads, self.head_size)
-        v = v.reshape(-1, self.num_heads, self.head_size)
+        q = q.reshape(-1, self.num_heads, self.head_dim)
+        k = k.reshape(-1, self.num_heads, self.head_dim)
+        v = v.reshape(-1, self.num_heads, self.head_dim)

        q, k = self.rotary_emb(positions, q, k)
        attn_output, kv_fused = self.attn(q, k, v, forward_batch, token_to_kv_pool)
@@ -169,7 +169,7 @@ def __init__(
        rngs: nnx.Rngs = None,
    ):
        self.layer_id = layer_id
-
+        head_dim = getattr(config, "head_dim_padded", None)
        self.ln_1 = RMSNorm(
            config.hidden_size,
            epsilon=config.layer_norm_epsilon,
@@ -186,6 +186,7 @@ def __init__(
            config.max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
+            head_dim=head_dim,
            layer_id=layer_id,
            dtype=dtype,
            rngs=rngs,
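QWenAttention now takes the padded head dimension explicitly, and the block-level code fetches it with getattr(config, "head_dim_padded", None), so a missing value falls back to the conventional hidden_size // num_heads inside the attention layer. A tiny self-contained sketch of that fallback (the config objects are hypothetical stand-ins):

# Hypothetical configs, only to illustrate the fallback used above.
from types import SimpleNamespace

hidden_size, num_heads = 4096, 32
cfg_with_padding = SimpleNamespace(head_dim_padded=128)
cfg_without = SimpleNamespace()  # e.g. the runner never set a padded value

for cfg in (cfg_with_padding, cfg_without):
    head_dim = getattr(cfg, "head_dim_padded", None) or hidden_size // num_heads
    print(head_dim)  # 128 either way here; the values differ when padding changes the dim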

python/sgl_jax/srt/models/qwen2.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def __init__(
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
        self.self_attn = Qwen2Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,

python/sgl_jax/srt/models/qwen3.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def __init__(
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
        self.self_attn = QWen3Attention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,

python/sgl_jax/srt/models/qwen3_moe.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def __init__(
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)

        self.self_attn = QWen3MoeAttention(
            hidden_size=config.hidden_size,
