@@ -665,6 +665,97 @@ async def cancel_responses(response_id: str, raw_request: Request):
     return JSONResponse(content=response.model_dump())


+if envs.VLLM_V1_SPANS_ENABLED:
+    import spnl
+    import time
+
+    from fastapi import Body
+
+    from vllm import SamplingParams
+    from vllm.inputs import TokensPrompt
+    from vllm.outputs import RequestOutput
+    from vllm.entrypoints.openai.protocol import (
+        ChatMessage, ChatCompletionResponseChoice, UsageInfo)
+
+    spnl_state = spnl.init(10)
+    PAD_TOKEN = 27
+    # Span-marker token ids; a negative env value means the marker is unset.
+    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
+                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
+    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
+                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)
+    def wrap(
+        prompt: list[int] | list[list[int]]
+    ) -> TokensPrompt | list[TokensPrompt]:
+        # spnl hands back either a single list of token ids or a list of such
+        # lists; wrap each one as a TokensPrompt for the engine.
+        if isinstance(prompt[0], list):
+            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
+        return TokensPrompt(prompt_token_ids=prompt)
+    @router.post("/v1/query/prepare")
+    @with_cancellation
+    @load_aware_call
+    async def prepare_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        docs = [
+            wrap(doc) for doc in spnl.tokenize_prepare(
+                spnl_state,
+                query,
+                True,  # preload the prefix of the plus/independent spans
+                PAD_TOKEN,
+                PLUS_TOKEN,
+                raw_request.app.state.vllm_config.cache_config.block_size,
+            )
+        ]
+        logger.debug("/v1/query/prepare %d %s", len(docs), docs)
+
+        request_id = raw_request.headers.get(
+            "X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generators = [
+            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
+                            request_id) for doc in docs
+        ]
+        # Drain each generator: the single sampled token is discarded; the
+        # point is the prefill, which leaves the span prefixes in the KV cache.
+        for generator in generators:
+            async for _ in generator:
+                pass
+
+        return JSONResponse(content={"success": True})
+
+    @router.post("/v1/query/execute")
+    @with_cancellation
+    @load_aware_call
+    async def execute_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        req = spnl.tokenize_query(
+            spnl_state,
+            query,
+            PAD_TOKEN,
+            CROSS_TOKEN,
+            PLUS_TOKEN,
+            raw_request.app.state.vllm_config.cache_config.block_size,
+        )
+        logger.debug("/v1/query/execute %s", req.messages)
+
+        request_id = raw_request.headers.get(
+            "X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        sampling_params = SamplingParams(
+            temperature=req.temperature if req.temperature is not None else 0,
+            max_tokens=req.max_tokens
+            if req.max_tokens is not None and req.max_tokens != 0 else 2048)
+        generator = client.generate(wrap(req.messages), sampling_params,
+                                    request_id)
+
+        # TODO streaming output...
+        final_res: Optional[RequestOutput] = None
+        async for res in generator:
+            final_res = res
+        final = final_res.outputs[0]
+        choices = [
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=final.text),
+                logprobs=final.logprobs,
+                finish_reason=final.finish_reason,
+                stop_reason=final.stop_reason,
+            )
+        ]
+        num_prompt_tokens = 0  # TODO
+        num_generated_tokens = 0  # TODO
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens + num_generated_tokens)
+        response = ChatCompletionResponse(id=request_id,
+                                          created=int(time.time()),
+                                          model=req.model,
+                                          choices=choices,
+                                          usage=usage)
+
+        return JSONResponse(content=response.model_dump())
+
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)],
              responses={
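For reference, a minimal client-side sketch of how these two endpoints might be exercised. This is an illustration, not part of the change: it assumes a server started with VLLM_V1_SPANS_ENABLED=1 listening on localhost:8000, and a hypothetical spnl query document stored in query.spnl. The prepare call warms the KV cache with the independent span prefixes; the execute call returns a chat-completion-style JSON body.

    import requests

    BASE = "http://localhost:8000"
    with open("query.spnl") as f:  # hypothetical query file
        query = f.read()
    headers = {"Content-Type": "text/plain"}

    # Pre-fill the plus/independent span prefixes into the KV cache.
    r = requests.post(f"{BASE}/v1/query/prepare", data=query, headers=headers)
    r.raise_for_status()

    # Execute the query; the response mirrors a chat completion payload.
    r = requests.post(f"{BASE}/v1/query/execute", data=query, headers=headers)
    r.raise_for_status()
    print(r.json()["choices"][0]["message"]["content"])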