Skip to content

Commit 106df4e

Browse files
llfl and XucSh authored
Fix mrope_positions size when req is retracted (#13700)
Signed-off-by: Kun(llfl) <[email protected]> Co-authored-by: Xuchun Shang <[email protected]>
1 parent 427a19b commit 106df4e

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

python/sglang/srt/managers/mm_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,3 +835,45 @@ def hash_feature(f):
835835
reconstruct_t = f.reconstruct_on_target_device(torch.cuda.current_device())
836836
return tensor_hash([reconstruct_t])
837837
return data_hash(f)
838+
839+
840+
def extend_mrope_positions_for_retracted_request(
    mrope_positions: torch.Tensor, output_ids_len: int
) -> torch.Tensor:
    """
    Extend mrope_positions for retracted requests by appending positions for output_ids.

    When a request is retracted and has multimodal inputs with mrope_positions,
    we need to extend the positions to cover the output_ids that were already generated.
    For pure text tokens, all three dimensions use the same incremental sequence.

    Args:
        mrope_positions: The original mrope positions tensor, shape (3, origin_input_ids_len)
        output_ids_len: The number of output tokens to generate positions for

    Returns:
        Extended mrope_positions tensor with shape (3, origin_input_ids_len + output_ids_len)
    """
    if output_ids_len <= 0:
        return mrope_positions

    # Pure-text tokens advance all three mrope dimensions (temporal/height/width)
    # together, continuing one past the LARGEST position seen so far. Taking the
    # max over all three dimensions (rather than only dim 0) matters when the
    # prompt ends with vision tokens, whose three position components differ;
    # this matches the Qwen2-VL rope-index convention of `prev.max() + 1`.
    # For text-final prompts all dims are equal, so this is a no-op change.
    start_pos = int(mrope_positions[:, -1].max().item()) + 1

    # Generate pure text mrope positions for output_ids.
    # All three dimensions for pure text are the same incremental sequence.
    output_positions = (
        torch.arange(
            start_pos,
            start_pos + output_ids_len,
            dtype=torch.int64,
            device=mrope_positions.device,
        )
        .unsqueeze(0)
        .expand(3, -1)
    )  # shape: (3, output_ids_len)

    # Concatenate to the original mrope_positions along the sequence dimension.
    return torch.cat([mrope_positions, output_positions], dim=1)

python/sglang/srt/managers/schedule_batch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,22 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
810810
match_result.host_hit_length,
811811
)
812812
self.cache_protected_len = len(self.prefix_indices)
813+
814+
if (
815+
self.is_retracted
816+
and self.multimodal_inputs is not None
817+
and self.multimodal_inputs.mrope_positions is not None
818+
):
819+
from sglang.srt.managers.mm_utils import (
820+
extend_mrope_positions_for_retracted_request,
821+
)
822+
823+
self.multimodal_inputs.mrope_positions = (
824+
extend_mrope_positions_for_retracted_request(
825+
self.multimodal_inputs.mrope_positions, len(self.output_ids)
826+
)
827+
)
828+
813829
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
814830

815831
# Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313

0 commit comments

Comments
 (0)