
Commit 72ee472

feat: /v1/query/{prepare,execute} endpoints for span queries
This PR adds two endpoints:

- `/v1/query/prepare`
- `/v1/query/execute`

These add REST APIs around the spans core feature. They take as input a [span query](https://github.com/IBM/spnl). This PR also adds an example query under examples/offline_inference/spans/query-{ab,ba}.json. See the [spans readme](examples/offline_inference/spans/README.md) for an example of usage.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent 1c3cd0e commit 72ee472


5 files changed, +166 −0 lines

examples/offline_inference/spans/README.md

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
# Span (a.k.a. Block Attention) Examples

## Span Queries

This directory contains a [span query](https://github.com/IBM/spnl#readme). To send a query, first prepare the query shape:

```bash
curl -s -XPOST http://localhost:8000/v1/query/prepare --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
1.504452
```

And then you can execute the query in either order, and you should see millisecond-level TTFT:

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ba.json -o /dev/null -w "%{time_total}\n"
0.077699
```

```bash
curl -s -XPOST http://localhost:8000/v1/query/execute --data @./query-ab.json -o /dev/null -w "%{time_total}\n"
0.078419
```
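The same workflow can also be driven from Python. The sketch below is illustrative only and not part of the commit: it assumes a spans-enabled server at localhost:8000, the two example query files in the working directory, the `requests` package installed, and that a `text/plain` body is acceptable (matching how the endpoints declare their request body).

```python
# Illustrative sketch only: assumes a vLLM server with spans enabled at
# localhost:8000 and the example query files in the current directory.
import time
from pathlib import Path

import requests

BASE_URL = "http://localhost:8000"


def post_query(endpoint: str, query_file: str) -> float:
    """POST a span query file and return the total request time in seconds."""
    body = Path(query_file).read_text()
    start = time.perf_counter()
    resp = requests.post(f"{BASE_URL}{endpoint}", data=body,
                         headers={"Content-Type": "text/plain"})
    resp.raise_for_status()
    return time.perf_counter() - start


# Prepare the query shape once (the slow step in the timings above)...
print("prepare   :", post_query("/v1/query/prepare", "query-ab.json"))
# ...then execute in either span order; both should return quickly.
print("execute ba:", post_query("/v1/query/execute", "query-ba.json"))
print("execute ab:", post_query("/v1/query/execute", "query-ab.json"))
```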

examples/offline_inference/spans/query-ab.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

examples/offline_inference/spans/query-ba.json

Lines changed: 26 additions & 0 deletions
Large diffs are not rendered by default.

requirements/common.txt

Lines changed: 1 addition & 0 deletions
@@ -49,3 +49,4 @@ pybase64 # fast base64 implementation
 cbor2 # Required for cross-language serialization of hashable objects
 setproctitle # Used to set process names for better debugging and monitoring
 openai-harmony >= 0.0.3 # Required for gpt-oss
+spnl >= 0.6.1

vllm/entrypoints/openai/api_server.py

Lines changed: 91 additions & 0 deletions
@@ -665,6 +665,97 @@ async def cancel_responses(response_id: str, raw_request: Request):
     return JSONResponse(content=response.model_dump())


+if envs.VLLM_V1_SPANS_ENABLED:
+    import spnl
+    import time
+    from fastapi import Body
+    from vllm import SamplingParams
+    from vllm.inputs import TokensPrompt
+    from vllm.outputs import RequestOutput
+    from vllm.entrypoints.openai.protocol import (ChatMessage,
+                                                  ChatCompletionResponseChoice,
+                                                  UsageInfo)
+
+    spnl_state = spnl.init(10)
+    PAD_TOKEN = 27
+    PLUS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_PLUS
+                  if envs.VLLM_V1_SPANS_TOKEN_PLUS >= 0 else None)
+    CROSS_TOKEN = (envs.VLLM_V1_SPANS_TOKEN_CROSS
+                   if envs.VLLM_V1_SPANS_TOKEN_CROSS >= 0 else None)
+
+    def wrap(prompt: list[int] | list[list[int]]) -> TokensPrompt | list[TokensPrompt]:
+        # Accept either a single token-id list or a batch of them.
+        if isinstance(prompt[0], list):
+            return [TokensPrompt(prompt_token_ids=p) for p in prompt]
+        return TokensPrompt(prompt_token_ids=prompt)
+
+    @router.post("/v1/query/prepare")
+    @with_cancellation
+    @load_aware_call
+    async def prepare_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        docs = [
+            wrap(doc) for doc in spnl.tokenize_prepare(
+                spnl_state,
+                query,
+                True,  # we need to preload the prefix of the plus/independent spans
+                PAD_TOKEN,
+                PLUS_TOKEN,
+                raw_request.app.state.vllm_config.cache_config.block_size)
+        ]
+        logger.debug(f"/v1/query/prepare {len(docs)} {docs}")
+
+        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        # Generate a single token per span so the engine prefills (and caches)
+        # each span's prefix.
+        generators = [
+            client.generate(doc, SamplingParams(temperature=0, max_tokens=1),
+                            request_id) for doc in docs
+        ]
+        for generator in generators:
+            async for res in generator:
+                final = res.outputs[0]
+
+        return JSONResponse(content={"success": True})
+
+    @router.post("/v1/query/execute")
+    @with_cancellation
+    @load_aware_call
+    async def execute_query(raw_request: Request,
+                            query: str = Body(..., media_type="text/plain")):
+        req = spnl.tokenize_query(
+            spnl_state,
+            query,
+            PAD_TOKEN,
+            CROSS_TOKEN,
+            PLUS_TOKEN,
+            raw_request.app.state.vllm_config.cache_config.block_size)
+        logger.debug(f"/v1/query/execute {req.messages}")
+
+        request_id = raw_request.headers.get("X-Request-Id") or uuid.uuid4().hex
+        client = engine_client(raw_request)
+        generator = client.generate(
+            wrap(req.messages),
+            SamplingParams(
+                temperature=req.temperature if req.temperature is not None else 0,
+                max_tokens=req.max_tokens
+                if req.max_tokens is not None and req.max_tokens != 0 else 2048),
+            request_id)
+
+        # TODO streaming output...
+        final_res: Optional[RequestOutput] = None
+        async for res in generator:
+            final_res = res
+        final = final_res.outputs[0]
+        choices = [
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=final.text),
+                logprobs=final.logprobs,
+                finish_reason=final.finish_reason,
+                stop_reason=final.stop_reason,
+            )
+        ]
+        num_prompt_tokens = 0  # TODO
+        num_generated_tokens = 0  # TODO
+        usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                          completion_tokens=num_generated_tokens,
+                          total_tokens=num_prompt_tokens + num_generated_tokens)
+        response = ChatCompletionResponse(id=request_id,
+                                          created=int(time.time()),
+                                          model=req.model,
+                                          choices=choices,
+                                          usage=usage)
+
+        return JSONResponse(content=response.model_dump())
+
 @router.post("/v1/chat/completions",
              dependencies=[Depends(validate_json_request)],
              responses={
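Because `execute_query` packages its result as a `ChatCompletionResponse` and returns `model_dump()`, a caller can read the output using the familiar chat-completion shape. The following is a minimal, hypothetical client sketch (assuming the same spans-enabled server and example query file as above; the usage counters are still TODO in this commit and therefore report zero):

```python
# Hypothetical client sketch: assumes a spans-enabled vLLM server on
# localhost:8000 and a query file from examples/offline_inference/spans.
from pathlib import Path

import requests

resp = requests.post("http://localhost:8000/v1/query/execute",
                     data=Path("query-ab.json").read_text(),
                     headers={"Content-Type": "text/plain"})
resp.raise_for_status()
result = resp.json()

# The payload is shaped like a chat completion response.
print(result["id"], result["model"])
print(result["choices"][0]["message"]["content"])  # generated text
print(result["usage"])  # all zeros until the TODO counters are filled in
```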
