Commit 2eecd31

Update the rest
Signed-off-by: Harry Mellor <[email protected]>
1 parent 19dcc18 commit 2eecd31

5 files changed: +40 additions, −62 deletions


vllm/model_executor/models/apertus.py

Lines changed: 8 additions & 14 deletions
@@ -119,7 +119,7 @@ def __init__(
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -177,7 +177,7 @@ def __init__(
         )

         self._init_rotary_emb(
-            config, rope_scaling=rope_scaling, quant_config=quant_config
+            config, rope_parameters=rope_parameters, quant_config=quant_config
         )

         sliding_window = None
@@ -224,7 +224,7 @@ def forward(
     def _init_rotary_emb(
         self,
         config: ApertusConfig,
-        rope_scaling: dict[str, Any] | None,
+        rope_parameters: dict[str, Any] | None,
         quant_config: QuantizationConfig | None,
     ) -> None:
         is_neox_style = True
@@ -237,7 +237,7 @@ def _init_rotary_emb(
             rotary_dim=int(self.partial_rotary_factor * self.head_dim),
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=is_neox_style,
             partial_rotary_factor=self.partial_rotary_factor,
         )
@@ -253,14 +253,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-            config, "original_max_position_embeddings", None
-        ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+        if ompe := getattr(config, "original_max_position_embeddings", None):
+            config.rope_parameters["original_max_position_embeddings"] = ompe
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
@@ -288,8 +282,8 @@ def __init__(
             num_kv_heads=getattr(
                 config, "num_key_value_heads", config.num_attention_heads
             ),
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_theta=config.rope_parameters["rope_theta"],
+            rope_parameters=config.rope_parameters,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
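
The same access-pattern change repeats in the files below. A minimal sketch of the before/after, using a stand-in config object rather than vLLM's real config classes (the dict contents here are illustrative assumptions):

from types import SimpleNamespace

# Stand-in for an HF-style model config. "rope_parameters" is assumed to be a
# plain dict that now also carries "rope_theta" (previously its own attribute).
config = SimpleNamespace(
    rope_parameters={"rope_type": "yarn", "factor": 4.0, "rope_theta": 10000.0},
    original_max_position_embeddings=4096,
)

# Old pattern (removed above): separate attributes plus a guarded copy.
#   rope_theta = getattr(config, "rope_theta", 10000)
#   rope_scaling = getattr(config, "rope_scaling", None)

# New pattern: everything is read from, and written into, config.rope_parameters.
if ompe := getattr(config, "original_max_position_embeddings", None):
    config.rope_parameters["original_max_position_embeddings"] = ompe

rope_theta = config.rope_parameters["rope_theta"]
print(rope_theta)              # 10000.0
print(config.rope_parameters)  # now also contains original_max_position_embeddings

One behavioural difference worth noting: the old code only copied original_max_position_embeddings when rope_scaling was present, whereas the new code writes it into config.rope_parameters whenever the attribute is set.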

vllm/model_executor/models/deepseek_v2.py

Lines changed: 16 additions & 18 deletions
@@ -410,7 +410,7 @@ def __init__(
         q_lora_rank: int,
         kv_lora_rank: int,
         rope_theta: float = 10000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
@@ -485,21 +485,21 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
+        if rope_parameters:
+            rope_parameters["rope_type"] = "deepseek_yarn"

         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=False,
         )

-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
+        if rope_parameters:
+            mscale_all_dim = rope_parameters.get("mscale_all_dim", False)
+            scaling_factor = rope_parameters["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

@@ -904,7 +904,7 @@ def __init__(
         q_lora_rank: int | None,
         kv_lora_rank: int,
         rope_theta: float = 10000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
@@ -981,19 +981,19 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )

-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
+        if rope_parameters:
+            rope_parameters["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=False,
         )
-        if rope_scaling:
-            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
-            scaling_factor = rope_scaling["factor"]
+        if rope_parameters:
+            mscale_all_dim = rope_parameters.get("mscale_all_dim", False)
+            scaling_factor = rope_parameters["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

@@ -1073,8 +1073,6 @@ def __init__(
         parallel_config = vllm_config.parallel_config

         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         moe_layer_freq = getattr(config, "moe_layer_freq", 1)
         # DecoderLayers are created with `make_layers` which passes the prefix
@@ -1107,8 +1105,8 @@ def __init__(
             v_head_dim=v_head_dim,
             q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
             kv_lora_rank=kv_lora_rank,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_theta=config.rope_parameters["rope_theta"],
+            rope_parameters=config.rope_parameters,
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
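
For reference, the yarn_get_mscale value computed above is squared into the attention scale. A self-contained sketch of the DeepSeek-style YaRN mscale formula together with the renamed dict lookups (the helper and the example numbers are illustrative; the module defines its own implementation):

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # No correction without context extension; logarithmic growth otherwise.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

# Mirrors the call site above, with rope_parameters standing in for the old
# rope_scaling dict. The values are made up for illustration.
rope_parameters = {
    "rope_type": "deepseek_yarn",
    "factor": 40.0,
    "mscale_all_dim": 1.0,
    "rope_theta": 10000.0,
}
mscale = yarn_get_mscale(rope_parameters["factor"],
                         float(rope_parameters.get("mscale_all_dim", False)))
base_scaling = 192 ** -0.5  # assumed qk head size; not taken from the diff
scaling = base_scaling * mscale * mscale
print(round(mscale, 4), round(scaling, 6))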

vllm/model_executor/models/exaone4.py

Lines changed: 6 additions & 12 deletions
@@ -111,7 +111,7 @@ def __init__(
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 1000000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -181,7 +181,7 @@ def __init__(
             rotary_dim=self.head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=is_neox_style,
         )
         self.attn = Attention(
@@ -227,14 +227,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 1000000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-            config, "original_max_position_embeddings", None
-        ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+        if ompe := getattr(config, "original_max_position_embeddings", None):
+            config.rope_parameters["original_max_position_embeddings"] = ompe
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
@@ -249,8 +243,8 @@ def __init__(
             num_kv_heads=getattr(
                 config, "num_key_value_heads", config.num_attention_heads
            ),
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_theta=config.rope_parameters["rope_theta"],
+            rope_parameters=config.rope_parameters,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
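
As in apertus.py, rope_theta is now read out of config.rope_parameters rather than from its own attribute. Purely for illustration, the dict for a YaRN-extended checkpoint might look something like this (only the keys actually touched by this commit are certain; the values and the remaining keys are assumptions):

# Keys read or written in this commit: "rope_theta", "rope_type", "factor",
# "mscale_all_dim" (DeepSeek only) and "original_max_position_embeddings".
rope_parameters = {
    "rope_theta": 1_000_000.0,                 # replaces config.rope_theta
    "rope_type": "yarn",                       # assumed scaling type
    "factor": 4.0,                             # assumed extension factor
    "original_max_position_embeddings": 8192,  # injected by the snippet above
}

# What the attention layer constructor now receives:
attn_kwargs = {
    "rope_theta": rope_parameters["rope_theta"],
    "rope_parameters": rope_parameters,
}
print(attn_kwargs["rope_theta"])  # 1000000.0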

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 2 additions & 4 deletions
@@ -274,10 +274,8 @@ def __init__(
                 self.head_dim,
                 rotary_dim=self.head_dim,
                 max_position=config.max_position_embeddings,
-                base=int(config.rope_theta),
-                rope_scaling=config.rope_scaling
-                if hasattr(config, "rope_scaling") and config.rope_scaling is not None
-                else None,
+                base=int(config.rope_parameters["rope_theta"]),
+                rope_parameters=config.rope_parameters,
                 is_neox_style=True,
             )
         else:

vllm/model_executor/models/llama.py

Lines changed: 8 additions & 14 deletions
@@ -121,7 +121,7 @@ def __init__(
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-        rope_scaling: dict[str, Any] | None = None,
+        rope_parameters: dict[str, Any] | None = None,
         max_position_embeddings: int = 8192,
         quant_config: QuantizationConfig | None = None,
         bias: bool = False,
@@ -187,7 +187,7 @@ def __init__(
         )

         self._init_rotary_emb(
-            config, rope_scaling=rope_scaling, quant_config=quant_config
+            config, rope_parameters=rope_parameters, quant_config=quant_config
         )

         sliding_window = None
@@ -258,7 +258,7 @@ def forward(
     def _init_rotary_emb(
         self,
         config: LlamaConfig,
-        rope_scaling: dict[str, Any] | None,
+        rope_parameters: dict[str, Any] | None,
         quant_config: QuantizationConfig | None,
     ) -> None:
         is_neox_style = True
@@ -271,7 +271,7 @@ def _init_rotary_emb(
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
-            rope_scaling=rope_scaling,
+            rope_parameters=rope_parameters,
             is_neox_style=is_neox_style,
             partial_rotary_factor=self.partial_rotary_factor,
         )
@@ -291,14 +291,8 @@ def __init__(
         quant_config = self.get_quant_config(vllm_config)

         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        if rope_scaling is not None and getattr(
-            config, "original_max_position_embeddings", None
-        ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+        if ompe := getattr(config, "original_max_position_embeddings", None):
+            config.rope_parameters["original_max_position_embeddings"] = ompe
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         # Support abacusai/Smaug-72B-v0.1 with attention_bias
         # Support internlm/internlm-7b with bias
@@ -326,8 +320,8 @@ def __init__(
             num_kv_heads=getattr(
                 config, "num_key_value_heads", config.num_attention_heads
             ),
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
+            rope_theta=config.rope_parameters["rope_theta"],
+            rope_parameters=config.rope_parameters,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
             bias=attention_bias,
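
A quick way to sanity-check the new injection logic is a small test along these lines (hypothetical test, not part of this commit; SimpleNamespace stands in for the real config class):

from types import SimpleNamespace

def inject_ompe(config) -> None:
    # Mirrors the two-line replacement used throughout this commit.
    if ompe := getattr(config, "original_max_position_embeddings", None):
        config.rope_parameters["original_max_position_embeddings"] = ompe

def test_injected_when_present():
    config = SimpleNamespace(
        rope_parameters={"rope_theta": 10000.0},
        original_max_position_embeddings=4096,
    )
    inject_ompe(config)
    assert config.rope_parameters["original_max_position_embeddings"] == 4096

def test_untouched_when_absent():
    config = SimpleNamespace(rope_parameters={"rope_theta": 10000.0})
    inject_ompe(config)
    assert "original_max_position_embeddings" not in config.rope_parameters

test_injected_when_present()
test_untouched_when_absent()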
