Skip to content

Commit 4b555a9

Browse files
authored
Adjusted Longrope embedding function to match Huggingface Implementation (#18422)
This updated implementation of longrope takes into account `long_factors` and `short_factors`, the scaling dictionaries provided via HF configs for MSFT's Phi3+ models. In the canonical HF implementation of longrope, once the sequence length exceeds a pre-configured threshold (`original_max_position_embeddings`), a different set of `ext_factors` must be used than the set used below that threshold. This patch enables that by packing both sets of scaling factors into a single argument (long factors in the first half, short factors in the second half) and selecting which set to use dynamically within the returned `prim_func`. The HF implementation can be found here: https://github.com/huggingface/transformers/blob/7b325cd573e40bbb12951b8446176c96e8b1afaa/src/transformers/modeling_rope_utils.py#L521. The link points directly to the switching logic between long and short factors, which this PR replicates.
1 parent d013dad commit 4b555a9

File tree

1 file changed

+75
-32
lines changed

1 file changed

+75
-32
lines changed

python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 75 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,10 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
464464
rotary_dim = head_dim
465465
scale = tir.const(scale, "float32")
466466
is_longrope_scaling = rope_scaling.get("rope_type") == "longrope"
467+
if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling:
468+
original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
469+
else:
470+
original_max_position_embeddings = 0
467471

468472
def _rope( # pylint: disable=too-many-arguments
469473
x: T.Buffer,
@@ -546,7 +550,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
546550
var_q: T.handle,
547551
var_k: T.handle,
548552
var_v: T.handle,
549-
ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore
553+
ext_factors: T.Buffer((rotary_dim,), "float32"), # type: ignore
550554
):
551555
T.func_attr(
552556
{
@@ -563,37 +567,76 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
563567
position_map = T.match_buffer(
564568
var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset
565569
)
566-
for iters in T.grid(seq_len, fused_heads, head_dim):
567-
with T.block("llama_fused_rope"):
568-
s, h, d = T.axis.remap("SSS", iters)
569-
if h < num_q_heads:
570-
q[s, h, d] = T.if_then_else(
571-
d < rotary_dim,
572-
_rope(
573-
qkv,
574-
s,
575-
h,
576-
d,
577-
position_map[s],
578-
ext_factors if is_longrope_scaling else None,
579-
),
580-
qkv[s, h, d],
581-
)
582-
elif h < num_q_heads + num_kv_heads:
583-
k[s, h - num_q_heads, d] = T.if_then_else(
584-
d < rotary_dim,
585-
_rope(
586-
qkv,
587-
s,
588-
h,
589-
d,
590-
position_map[s],
591-
ext_factors if is_longrope_scaling else None,
592-
),
593-
qkv[s, h, d],
594-
)
595-
else:
596-
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
570+
# long factors is the first half, short factors is the second half
571+
long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data)
572+
short_factors = T.Buffer(
573+
(rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2)
574+
)
575+
576+
if seq_len > original_max_position_embeddings:
577+
for iters in T.grid(seq_len, fused_heads, head_dim):
578+
with T.block("llama_fused_rope"):
579+
s, h, d = T.axis.remap("SSS", iters)
580+
if h < num_q_heads:
581+
q[s, h, d] = T.if_then_else(
582+
d < rotary_dim,
583+
_rope(
584+
qkv,
585+
s,
586+
h,
587+
d,
588+
position_map[s],
589+
long_factors if is_longrope_scaling else None,
590+
),
591+
qkv[s, h, d],
592+
)
593+
elif h < num_q_heads + num_kv_heads:
594+
k[s, h - num_q_heads, d] = T.if_then_else(
595+
d < rotary_dim,
596+
_rope(
597+
qkv,
598+
s,
599+
h,
600+
d,
601+
position_map[s],
602+
long_factors if is_longrope_scaling else None,
603+
),
604+
qkv[s, h, d],
605+
)
606+
else:
607+
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
608+
else:
609+
for iters in T.grid(seq_len, fused_heads, head_dim):
610+
with T.block("llama_fused_rope"):
611+
s, h, d = T.axis.remap("SSS", iters)
612+
if h < num_q_heads:
613+
q[s, h, d] = T.if_then_else(
614+
d < rotary_dim,
615+
_rope(
616+
qkv,
617+
s,
618+
h,
619+
d,
620+
position_map[s],
621+
short_factors if is_longrope_scaling else None,
622+
),
623+
qkv[s, h, d],
624+
)
625+
elif h < num_q_heads + num_kv_heads:
626+
k[s, h - num_q_heads, d] = T.if_then_else(
627+
d < rotary_dim,
628+
_rope(
629+
qkv,
630+
s,
631+
h,
632+
d,
633+
position_map[s],
634+
short_factors if is_longrope_scaling else None,
635+
),
636+
qkv[s, h, d],
637+
)
638+
else:
639+
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]
597640

598641
if is_longrope_scaling:
599642
return fused_rope_longrope_scaling

0 commit comments

Comments
 (0)