
Commit a9d18b5

[Bugfix] Fix gpt_oss packed_modules_mapping (#28536)
Signed-off-by: Jee Jee Li <[email protected]>
1 parent edb59a9 commit a9d18b5
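
For context: in vLLM, models that fuse the q/k/v projections into one linear layer conventionally declare a class-level packed_modules_mapping keyed by the fused module's attribute name; LoRA packing uses those keys to locate the fused layer. A minimal sketch of that convention follows, assuming the usual key names (the exact dict in gpt_oss.py is not part of this diff):

# Conventional vLLM mapping, shown only as an assumption; the keys must match
# the attribute names created in __init__, which is why self.qkv is renamed
# to self.qkv_proj in this commit.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}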

File tree

1 file changed (+5, -5 lines)


vllm/model_executor/models/gpt_oss.py

Lines changed: 5 additions & 5 deletions
@@ -92,7 +92,7 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = config.rope_theta
 
-        self.qkv = QKVParallelLinear(
+        self.qkv_proj = QKVParallelLinear(
             hidden_size=self.hidden_size,
             head_size=self.head_dim,
             total_num_heads=self.num_attention_heads,
@@ -129,7 +129,7 @@ def __init__(
     def forward(
         self, hidden_states: torch.Tensor, positions: torch.Tensor
     ) -> torch.Tensor:
-        qkv, _ = self.qkv(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
@@ -606,9 +606,9 @@ def _load_weights_other(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv", ".q_proj", "q"),
-            (".qkv", ".k_proj", "k"),
-            (".qkv", ".v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
         ]
 
         tp_rank = get_tensor_model_parallel_rank()
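
For background on why the tuple prefixes must match the module attribute, here is a minimal sketch of how a stacked_params_mapping is typically consumed during weight loading in vLLM-style model code. The helper name load_stacked_weights and the surrounding scaffolding are illustrative assumptions, not code from gpt_oss.py:

# Illustrative sketch, not the actual gpt_oss.py implementation.
# Checkpoint tensors arrive under per-projection names ("...q_proj.weight"),
# while the model registers one fused parameter under the attribute name
# ("...qkv_proj.weight"). The mapping rewrites the former into the latter,
# so its first element must equal the attribute created in __init__.
from collections.abc import Iterable

import torch
import torch.nn as nn


def load_stacked_weights(
    model: nn.Module, weights: Iterable[tuple[str, torch.Tensor]]
) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    params_dict = dict(model.named_parameters())
    loaded: set[str] = set()

    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            # e.g. "layers.0.attn.q_proj.weight" -> "layers.0.attn.qkv_proj.weight"
            fused_name = name.replace(shard_name, param_name)
            param = params_dict[fused_name]
            # Fused parallel layers in vLLM attach a weight_loader that copies
            # this shard into the correct slice of the fused tensor.
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None:
                weight_loader(param, loaded_weight, shard_id)
            else:
                # Fallback for a plain parameter (sketch only).
                param.data.copy_(loaded_weight)
            loaded.add(fused_name)
            break
        else:
            # Not a stacked projection: plain elementwise copy.
            param = params_dict[name]
            param.data.copy_(loaded_weight)
            loaded.add(name)
    return loaded

The attribute rename and the mapping update have to move together: if only one side changed, the rewritten fused name would no longer resolve to a registered parameter.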
