Commit 8bcb9a3

fix: Use Rust Ingress (dynamo-run) for the Frontend (#1391) (#1399)
1 parent b785163 commit 8bcb9a3

19 files changed: +250 −1186 lines

examples/tensorrt_llm/common/base_engine.py

Lines changed: 89 additions & 57 deletions
@@ -21,25 +21,20 @@
 import signal
 import threading
 from contextlib import asynccontextmanager
-from dataclasses import asdict
 from enum import Enum
 from queue import Queue
 from typing import Any, Optional
 
 from common.parser import LLMAPIConfig
-from common.protocol import (
-    DisaggregatedTypeConverter,
-    TRTLLMWorkerRequest,
-    TRTLLMWorkerResponse,
-    TRTLLMWorkerResponseOutput,
-)
+from common.protocol import DisaggregatedTypeConverter
 from common.utils import ManagedThread, ServerType
 from tensorrt_llm.executor import CppExecutorError
 from tensorrt_llm.llmapi import LLM, SamplingParams
 from tensorrt_llm.llmapi.disagg_utils import (
     CtxGenServerConfig,
     parse_disagg_config_file,
 )
+from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
 from tensorrt_llm.serve.openai_protocol import DisaggregatedParams
 
 from dynamo.llm import KvEventPublisher, KvMetricsPublisher
@@ -65,14 +60,26 @@ def update_args_from_disagg_config(
     return engine_config
 
 
-def get_sampling_params(sampling_params):
-    # Removes keys starting with '_' from the sampling params which gets
-    # added by the LLM API. TRTLLM does not support creating SamplingParams
-    # from a dictionary with keys starting with '_'.
-    cleaned_dict = {
-        key: value for key, value in sampling_params.items() if not key.startswith("_")
-    }
-    return SamplingParams(**cleaned_dict)
+def _to_signed_i64(value: int | None) -> int | None:
+    """Convert a Python int to signed 64-bit range by two's complement."""
+    if value is None:
+        return None
+
+    if value >= 2**63:
+        return value - 2**64
+    if value < -(2**63):
+        return ((value + 2**63) % 2**64) - 2**63
+    return value
+
+
+def get_sampling_params(sampling_params_dict, default_sampling_params):
+    sampling_params = copy.deepcopy(default_sampling_params)
+    for key, value in sampling_params_dict.items():
+        if value is None:
+            continue
+        if hasattr(sampling_params, key):
+            setattr(sampling_params, key, value)
+    return sampling_params
 
 
 class BaseTensorrtLLMEngine:
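Reviewer note, not part of the commit: a minimal sketch of how the two new helpers behave. _to_signed_i64 maps unsigned 64-bit block hashes into the signed i64 range via two's complement, and get_sampling_params overlays non-None request options onto a copy of the model's default SamplingParams. Assumes the helpers above and TRT-LLM's SamplingParams are importable.

    # Illustrative only; values are made up, behavior follows the helpers above.
    assert _to_signed_i64(None) is None
    assert _to_signed_i64(42) == 42
    assert _to_signed_i64(2**64 - 1) == -1        # 0xFFFFFFFFFFFFFFFF wraps to -1
    assert _to_signed_i64(2**63) == -(2**63)      # smallest signed 64-bit value

    defaults = SamplingParams()                   # from tensorrt_llm.llmapi
    merged = get_sampling_params({"temperature": 0.7, "top_p": None}, defaults)
    assert merged.temperature == 0.7              # non-None override is applied
    assert merged.top_p == defaults.top_p         # None values leave defaults intact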
@@ -161,6 +168,12 @@ def _init_engine(self):
             target=asyncio.run, args=(self._run_llm_engine(),)
         )
 
+        # Populate default sampling params from the model
+        tokenizer = tokenizer_factory(self._engine_config.model_name)
+        self._default_sampling_params = SamplingParams()
+        self._default_sampling_params._setup(tokenizer)
+        self._default_sampling_params.stop = None
+
         self.publish_kv_cache_events_thread = None
         self.publish_stats_thread = None
 
@@ -308,13 +321,13 @@ async def publish_kv_cache_events_task(self):
             event_id = event["event_id"]
             data = event["data"]
             if data["type"] == "stored":
-                parent_hash = data["parent_hash"]
+                parent_hash = _to_signed_i64(data["parent_hash"])
                 token_ids = []
                 num_block_tokens = []
                 block_hashes = []
                 for block in data["blocks"]:
                     token_num_in_block = len(block["tokens"])
-                    block_hash = block["block_hash"]
+                    block_hash = _to_signed_i64(block["block_hash"])
                     if token_num_in_block > self._kv_block_size:
                         logger.error(
                             f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
@@ -350,6 +363,7 @@ async def publish_kv_cache_events_task(self):
             elif data["type"] == "removed":
                 block_hashes = []
                 for block_hash in data["block_hashes"]:
+                    block_hash = _to_signed_i64(block_hash)
                     if block_hash in self._partial_block_hashes:
                         logger.debug(
                             f"Skipping removing block hash {block_hash} since it is a partial block"
@@ -458,15 +472,16 @@ async def async_llm_wrapper():
 
     async def _get_remote_prefill_response(self, request):
         prefill_request = copy.deepcopy(request)
-        prefill_request.sampling_params["max_tokens"] = 1
+        # TRTLLM requires max_tokens to be set for prefill requests.
+        prefill_request.stop_conditions.max_tokens = 1
         prefill_request.disaggregated_params = DisaggregatedParams(
             request_type=DisaggRequestType.CONTEXT_ONLY.value
         )
 
         if self._prefill_client is None:
             raise ValueError("Prefill client not initialized")
 
-        # TODO: Use smart KV router to determine which prefill worker to use.
+        # TODO: Use smart KV router to determine which prefill worker to use. This would also require supporting publishing events for prefill workers.
         ctx_responses = [
             ctx_response
             async for ctx_response in await self._prefill_client.round_robin(
@@ -480,17 +495,10 @@ async def _get_remote_prefill_response(self, request):
         logger.debug(
             f"Received response from prefill worker: {ctx_responses[0].data()}"
         )
-        ctx_response_obj = TRTLLMWorkerResponse.model_validate_json(
-            ctx_responses[0].data()
-        )
-        ctx_response_obj.outputs = [
-            TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
-        ]
-        assert ctx_response_obj.outputs[0].disaggregated_params is not None
-
-        return ctx_response_obj
+        remote_prefill_response = ctx_responses[0]
+        return remote_prefill_response
 
-    async def generate(self, request: TRTLLMWorkerRequest):
+    async def generate(self, request):
         if self._llm_engine is None:
             raise RuntimeError("Engine not initialized")
 
@@ -500,32 +508,41 @@ async def generate(self, request: TRTLLMWorkerRequest):
         self._ongoing_request_count += 1
 
         try:
-            worker_inputs = request.tokens.tokens
+            worker_inputs = request.token_ids
 
             disaggregated_params = (
                 DisaggregatedTypeConverter.to_llm_disaggregated_params(
                     request.disaggregated_params
                 )
             )
 
-            if self._remote_prefill and self._server_type == ServerType.GEN:
-                ctx_response_obj = await self._get_remote_prefill_response(request)
+            num_output_tokens_so_far = 0
 
-                yield TRTLLMWorkerResponse(
-                    request_id=request.id,
-                    prompt_token_ids=ctx_response_obj.prompt_token_ids,
-                    outputs=[asdict(ctx_response_obj.outputs[0])],
-                    finished=ctx_response_obj.finished,
-                ).model_dump_json(exclude_unset=True)
+            if self._remote_prefill and self._server_type == ServerType.GEN:
+                ctx_response = await self._get_remote_prefill_response(request)
+                remote_prefill_response = ctx_response.data()
+                if (
+                    remote_prefill_response["finish_reason"] == "stop"
+                    or remote_prefill_response["finish_reason"] == "error"
+                ):
+                    yield remote_prefill_response
+                    return
+                num_output_tokens_so_far = len(remote_prefill_response["token_ids"])
 
-                worker_inputs = ctx_response_obj.prompt_token_ids
+                # Decode the disaggregated params from the remote prefill response
                 disaggregated_params = (
                     DisaggregatedTypeConverter.to_llm_disaggregated_params(
                         DisaggregatedParams(
-                            **ctx_response_obj.outputs[0].disaggregated_params
+                            **remote_prefill_response["disaggregated_params"]
                         )
                     )
                 )
+
+                # Send the first token response to the client
+                first_token_response = remote_prefill_response
+                first_token_response.pop("disaggregated_params")
+                yield first_token_response
+
                 disaggregated_params.request_type = (
                     DisaggRequestType.GENERATION_ONLY.value
                 )
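Reviewer note, not part of the commit: the keys read from remote_prefill_response above imply the prefill worker now returns a plain dict rather than the removed TRTLLMWorkerResponse model. A hypothetical example of that payload, inferred only from the accesses in this hunk (field values are made up; the actual schema lives in the prefill worker):

    remote_prefill_response = {
        "token_ids": [42],            # the single token produced during prefill
        "finish_reason": None,        # "stop" / "error" short-circuit the request
        "disaggregated_params": {},   # kwargs for DisaggregatedParams (contents omitted here)
    }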
@@ -534,29 +551,44 @@ async def generate(self, request: TRTLLMWorkerRequest):
                 f"Worker inputs: {worker_inputs}, disaggregated params: {disaggregated_params}"
             )
 
-            sampling_params = get_sampling_params(request.sampling_params)
+            sampling_params = get_sampling_params(
+                request.sampling_options.dict(), self._default_sampling_params
+            )
+            max_tokens = request.stop_conditions.max_tokens
+            if max_tokens:
+                sampling_params.max_tokens = max_tokens
+
             async for response in self._llm_engine.generate_async(
                 inputs=worker_inputs,
                 sampling_params=sampling_params,
                 disaggregated_params=disaggregated_params,
-                streaming=False
-                if self._server_type == ServerType.CTX
-                else request.streaming,
+                streaming=self._server_type != ServerType.CTX,
             ):
-                # Convert the disaggregated params to OAI format so
-                # it can be sent over the network.
-                response.outputs[
-                    0
-                ].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params(
-                    response.outputs[0].disaggregated_params
-                )
+                if response.finished and self._server_type != ServerType.CTX:
+                    yield {"finish_reason": "stop", "token_ids": []}
+                    break
+
+                if not response.outputs:
+                    yield {"finish_reason": "error", "token_ids": []}
+                    break
 
-                yield TRTLLMWorkerResponse(
-                    request_id=request.id,
-                    prompt_token_ids=response.prompt_token_ids,
-                    outputs=[asdict(response.outputs[0])],
-                    finished=response.finished,
-                ).model_dump_json(exclude_unset=True)
+                output = response.outputs[0]
+                next_total_toks = len(output.token_ids)
+                out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
+                if output.finish_reason:
+                    out["finish_reason"] = output.finish_reason
+                if output.stop_reason:
+                    out["stop_reason"] = output.stop_reason
+                if self._server_type == ServerType.CTX:
+                    # Return the disaggregated params only when operating in prefill mode.
+                    out[
+                        "disaggregated_params"
+                    ] = DisaggregatedTypeConverter.to_oai_disaggregated_params(
+                        output.disaggregated_params
+                    ).dict()
+
+                yield out
+                num_output_tokens_so_far = next_total_toks
 
 
         except CppExecutorError:
             signal.raise_signal(signal.SIGINT)