@@ -835,3 +835,45 @@ def hash_feature(f):
         reconstruct_t = f.reconstruct_on_target_device(torch.cuda.current_device())
         return tensor_hash([reconstruct_t])
     return data_hash(f)
+
+
+def extend_mrope_positions_for_retracted_request(
+    mrope_positions: torch.Tensor, output_ids_len: int
+) -> torch.Tensor:
+    """
+    Extend mrope_positions for retracted requests by appending positions for output_ids.
+
+    When a request is retracted and has multimodal inputs with mrope_positions,
+    we need to extend the positions to cover the output_ids that were already generated.
+    For pure text tokens, all three dimensions use the same incremental sequence.
+
+    Args:
+        mrope_positions: The original mrope positions tensor, shape (3, origin_input_ids_len)
+        output_ids_len: The number of output tokens to generate positions for
+
+    Returns:
+        Extended mrope_positions tensor with shape (3, origin_input_ids_len + output_ids_len)
+    """
+    if output_ids_len <= 0:
+        return mrope_positions
+
+    # Get the last position value corresponding to origin_input_ids
+    # mrope_positions shape: (3, origin_input_ids_len)
+    last_position = mrope_positions[:, -1]  # shape: (3,)
+
+    # Generate pure text mrope positions for output_ids
+    # All three dimensions for pure text are the same incremental sequence
+    start_pos = last_position[0] + 1  # Start from last position + 1
+    output_positions = (
+        torch.arange(
+            start_pos,
+            start_pos + output_ids_len,
+            dtype=torch.int64,
+            device=mrope_positions.device,
+        )
+        .unsqueeze(0)
+        .expand(3, -1)
+    )  # shape: (3, output_ids_len)
+
+    # Concatenate to the original mrope_positions
+    return torch.cat([mrope_positions, output_positions], dim=1)