
Commit 8daf5e4

fix head dim pad
1 parent de20719 · commit 8daf5e4

8 files changed: +22 -17 lines changed

python/sgl_jax/srt/configs/model_config.py

Lines changed: 3 additions & 0 deletions
@@ -169,6 +169,9 @@ def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
             **kwargs,
         )
 
+    def get_padded_head_dim(self) -> int:
+        return (self.head_dim + 127) // 128 * 128
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads (original, not replicated)."""

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -205,6 +205,7 @@ def load_model(self):
         self.model_config.configure_for_tensor_parallel(self.tp_size)
         self.model_config.log_kv_heads_info(self.tp_size)
         self.model_config.hf_config.ep_size = self.ep_size
+        self.model_config.hf_config.head_dim_padded = self.model_config.get_padded_head_dim()
 
         self.model = self.model_loader.load_model(
             model_config=self.model_config,
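
load_model now stashes the padded value on hf_config next to ep_size, so each model's __init__ can pick it up with a plain getattr. A rough sketch of that hand-off, using SimpleNamespace as a stand-in for the real HF config object (the attribute values are made up):

from types import SimpleNamespace

hf_config = SimpleNamespace(head_dim=96)   # stand-in for model_config.hf_config
hf_config.head_dim_padded = (hf_config.head_dim + 127) // 128 * 128   # what load_model now sets

# what the model constructors below read:
head_dim = getattr(hf_config, "head_dim_padded", None)
print(head_dim)  # 128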

python/sgl_jax/srt/models/bailing_moe.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        self.head_dim = getattr(config, "head_dim", None)
+        self.head_dim = getattr(config, "head_dim_padded", None)
         use_qk_norm = getattr(config, "use_qk_norm", False)
         if hasattr(config, "partial_rotary_factor"):
             rotary_dim = int(self.head_dim * config.partial_rotary_factor)
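
Since self.head_dim now holds the padded value, the partial-rotary branch above derives rotary_dim from it as well. A tiny worked example of that line with made-up numbers:

# Made-up values, only to show the arithmetic of the rotary_dim line.
head_dim_padded = 128
partial_rotary_factor = 0.5
rotary_dim = int(head_dim_padded * partial_rotary_factor)
print(rotary_dim)  # 64 -- computed from the padded head dim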

python/sgl_jax/srt/models/llama.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def __init__(
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False)
 
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
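
The same one-line swap appears in llama.py and the Qwen variants below: the per-layer head_dim now comes from head_dim_padded, and the attention modules are assumed to fall back to hidden_size // num_heads when it is None (their constructors are not shown in this diff). A hedged sketch of that pattern with stand-in classes:

# Stand-ins only; LlamaAttention's real signature is not part of this commit.
class AttnStub:
    def __init__(self, hidden_size: int, num_heads: int, head_dim: int | None = None):
        # assumed fallback when no padded value is provided
        self.head_dim = head_dim or hidden_size // num_heads

class CfgStub:
    hidden_size = 3072
    num_attention_heads = 32
    head_dim_padded = 128          # 3072 // 32 = 96, padded up to 128

cfg = CfgStub()
head_dim = getattr(cfg, "head_dim_padded", None)   # the changed line
attn = AttnStub(cfg.hidden_size, cfg.num_attention_heads, head_dim=head_dim)
print(attn.head_dim)  # 128 with the padded attribute set; 96 if it were missing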

python/sgl_jax/srt/models/qwen.py

Lines changed: 13 additions & 12 deletions
@@ -77,15 +77,15 @@ def __init__(
         max_position_embeddings: int,
         rope_theta: float = 10000,
         rope_scaling: dict[str, Any] | None = None,
+        head_dim: int | None = None,
         layer_id: int = 0,
         dtype: jnp.dtype = jnp.float16,
         rngs: nnx.Rngs = None,
     ):
         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        head_size = hidden_size // num_heads
-        self.head_size = head_size
-        self.scaling = head_size**-0.5
+        self.head_dim = head_dim or hidden_size // num_heads
+        self.scaling = head_dim**-0.5
 
         self.q_proj = LinearBase(
             input_size=hidden_size,
@@ -112,7 +112,7 @@ def __init__(
             params_dtype=dtype,
         )
         self.c_proj = LinearBase(
-            input_size=num_heads * head_size,
+            input_size=num_heads * head_dim,
             output_size=hidden_size,
             use_bias=False,
             kernel_axes=("tensor", None),
@@ -122,17 +122,17 @@ def __init__(
 
         # Use torch version of RotaryEmbedding directly
         self.rotary_emb = RotaryEmbedding(
-            head_size=head_size,
-            rotary_dim=head_size,
+            head_size=head_dim,
+            rotary_dim=head_dim,
             max_position_embeddings=max_position_embeddings,
             base=rope_theta,
             is_neox_style=True,
             dtype=dtype,
         )
-        self.scaling = head_size**-0.5
+        self.scaling = head_dim**-0.5
         self.attn = RadixAttention(
             num_heads=num_heads,
-            head_dim=head_size,
+            head_dim=head_dim,
             scaling=self.scaling,
             num_kv_heads=num_heads,
             layer_id=layer_id,
@@ -150,9 +150,9 @@ def __call__(
         k, _ = self.k_proj(hidden_states)
         v, _ = self.v_proj(hidden_states)
 
-        q = q.reshape(-1, self.num_heads, self.head_size)
-        k = k.reshape(-1, self.num_heads, self.head_size)
-        v = v.reshape(-1, self.num_heads, self.head_size)
+        q = q.reshape(-1, self.num_heads, self.head_dim)
+        k = k.reshape(-1, self.num_heads, self.head_dim)
+        v = v.reshape(-1, self.num_heads, self.head_dim)
 
         q, k = self.rotary_emb(positions, q, k)
         attn_output, kv_fused = self.attn(q, k, v, forward_batch, token_to_kv_pool)
@@ -169,7 +169,7 @@ def __init__(
         rngs: nnx.Rngs = None,
     ):
         self.layer_id = layer_id
-
+        head_dim = getattr(config, "head_dim_padded", None)
         self.ln_1 = RMSNorm(
             config.hidden_size,
             epsilon=config.layer_norm_epsilon,
@@ -186,6 +186,7 @@ def __init__(
             config.max_position_embeddings,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            head_dim=head_dim,
             layer_id=layer_id,
             dtype=dtype,
             rngs=rngs,
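
QWenAttention used to derive head_size from hidden_size // num_heads; it now takes an explicit head_dim (threaded in from head_dim_padded by the block's __init__) and only derives the value when none is passed. A minimal sketch of that resolution with made-up sizes; note that the diff computes scaling from the head_dim argument itself, so this path assumes head_dim is actually provided:

# Made-up sizes, to show how the new head_dim argument resolves.
hidden_size, num_heads = 2048, 16
head_dim = 128                                    # forwarded from getattr(config, "head_dim_padded", None)

resolved = head_dim or hidden_size // num_heads   # matches self.head_dim in the diff
scaling = head_dim ** -0.5                        # the diff scales by the passed-in head_dim
print(resolved, round(scaling, 4))                # 128 0.0884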

python/sgl_jax/srt/models/qwen2.py

Lines changed: 1 addition & 1 deletion
@@ -179,7 +179,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = Qwen2Attention(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,

python/sgl_jax/srt/models/qwen3.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
         self.self_attn = QWen3Attention(
             hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,

python/sgl_jax/srt/models/qwen3_moe.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def __init__(
         rope_theta = getattr(config, "rope_theta", 1000000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 40960)
-        head_dim = getattr(config, "head_dim", None)
+        head_dim = getattr(config, "head_dim_padded", None)
 
         self.self_attn = QWen3MoeAttention(
             hidden_size=config.hidden_size,
