import signal
import threading
from contextlib import asynccontextmanager
- from dataclasses import asdict
from enum import Enum
from queue import Queue
from typing import Any, Optional

from common.parser import LLMAPIConfig
- from common.protocol import (
-     DisaggregatedTypeConverter,
-     TRTLLMWorkerRequest,
-     TRTLLMWorkerResponse,
-     TRTLLMWorkerResponseOutput,
- )
+ from common.protocol import DisaggregatedTypeConverter
from common.utils import ManagedThread, ServerType
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import LLM, SamplingParams
from tensorrt_llm.llmapi.disagg_utils import (
    CtxGenServerConfig,
    parse_disagg_config_file,
)
+ from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from tensorrt_llm.serve.openai_protocol import DisaggregatedParams

from dynamo.llm import KvEventPublisher, KvMetricsPublisher
@@ -65,14 +60,26 @@ def update_args_from_disagg_config(
    return engine_config


- def get_sampling_params(sampling_params):
-     # Removes keys starting with '_' from the sampling params which gets
-     # added by the LLM API. TRTLLM does not support creating SamplingParams
-     # from a dictionary with keys starting with '_'.
-     cleaned_dict = {
-         key: value for key, value in sampling_params.items() if not key.startswith("_")
-     }
-     return SamplingParams(**cleaned_dict)
+ def _to_signed_i64(value: int | None) -> int | None:
+     """Convert a Python int into the signed 64-bit range via two's complement."""
+     if value is None:
+         return None
+
+     if value >= 2**63:
+         return value - 2**64
+     if value < -(2**63):
+         return ((value + 2**63) % 2**64) - 2**63
+     return value
+
+
+ def get_sampling_params(sampling_params_dict, default_sampling_params):
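+     # Copy the defaults, then overwrite any field the request explicitly sets;
+     # None means "not specified" and keeps the default value.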
+     sampling_params = copy.deepcopy(default_sampling_params)
+     for key, value in sampling_params_dict.items():
+         if value is None:
+             continue
+         if hasattr(sampling_params, key):
+             setattr(sampling_params, key, value)
+     return sampling_params


class BaseTensorrtLLMEngine:
@@ -161,6 +168,12 @@ def _init_engine(self):
            target=asyncio.run, args=(self._run_llm_engine(),)
        )

+         # Populate default sampling params from the model
+         tokenizer = tokenizer_factory(self._engine_config.model_name)
+         self._default_sampling_params = SamplingParams()
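+         # Note: SamplingParams._setup() is a private LLM API helper that derives
+         # tokenizer-dependent defaults (e.g. end/pad token ids), so it may change
+         # across TRTLLM versions.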
+         self._default_sampling_params._setup(tokenizer)
+         self._default_sampling_params.stop = None
+
        self.publish_kv_cache_events_thread = None
        self.publish_stats_thread = None

@@ -308,13 +321,13 @@ async def publish_kv_cache_events_task(self):
            event_id = event["event_id"]
            data = event["data"]
            if data["type"] == "stored":
-                 parent_hash = data["parent_hash"]
+                 parent_hash = _to_signed_i64(data["parent_hash"])
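+                 # Hashes may fall outside the signed 64-bit range; convert them
+                 # to the i64 representation the KV event publisher expects.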
                token_ids = []
                num_block_tokens = []
                block_hashes = []
                for block in data["blocks"]:
                    token_num_in_block = len(block["tokens"])
-                     block_hash = block["block_hash"]
+                     block_hash = _to_signed_i64(block["block_hash"])
                    if token_num_in_block > self._kv_block_size:
                        logger.error(
                            f"Block {block_hash} contains {token_num_in_block} tokens, which is greater than kv_block_size {self._kv_block_size}"
@@ -350,6 +363,7 @@ async def publish_kv_cache_events_task(self):
            elif data["type"] == "removed":
                block_hashes = []
                for block_hash in data["block_hashes"]:
+                     block_hash = _to_signed_i64(block_hash)
                    if block_hash in self._partial_block_hashes:
                        logger.debug(
                            f"Skipping removing block hash {block_hash} since it is a partial block"
@@ -458,15 +472,16 @@ async def async_llm_wrapper():

    async def _get_remote_prefill_response(self, request):
        prefill_request = copy.deepcopy(request)
-         prefill_request.sampling_params["max_tokens"] = 1
+         # TRTLLM requires max_tokens to be set for prefill requests.
+         prefill_request.stop_conditions.max_tokens = 1
        prefill_request.disaggregated_params = DisaggregatedParams(
            request_type=DisaggRequestType.CONTEXT_ONLY.value
        )

        if self._prefill_client is None:
            raise ValueError("Prefill client not initialized")

-         # TODO: Use smart KV router to determine which prefill worker to use.
+         # TODO: Use smart KV router to determine which prefill worker to use.
+         # This would also require prefill workers to publish KV events.
        ctx_responses = [
            ctx_response
            async for ctx_response in await self._prefill_client.round_robin(
@@ -480,17 +495,10 @@ async def _get_remote_prefill_response(self, request):
        logger.debug(
            f"Received response from prefill worker: {ctx_responses[0].data()}"
        )
-         ctx_response_obj = TRTLLMWorkerResponse.model_validate_json(
-             ctx_responses[0].data()
-         )
-         ctx_response_obj.outputs = [
-             TRTLLMWorkerResponseOutput(**ctx_response_obj.outputs[0])
-         ]
-         assert ctx_response_obj.outputs[0].disaggregated_params is not None
-
-         return ctx_response_obj
+         remote_prefill_response = ctx_responses[0]
+         return remote_prefill_response

-     async def generate(self, request: TRTLLMWorkerRequest):
+     async def generate(self, request):
        if self._llm_engine is None:
            raise RuntimeError("Engine not initialized")

@@ -500,32 +508,41 @@ async def generate(self, request: TRTLLMWorkerRequest):
        self._ongoing_request_count += 1

        try:
-             worker_inputs = request.tokens.tokens
+             worker_inputs = request.token_ids

            disaggregated_params = (
                DisaggregatedTypeConverter.to_llm_disaggregated_params(
                    request.disaggregated_params
                )
            )

-             if self._remote_prefill and self._server_type == ServerType.GEN:
-                 ctx_response_obj = await self._get_remote_prefill_response(request)
+             num_output_tokens_so_far = 0
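+             # Tracks how many output tokens have already been sent (e.g. the
+             # remote prefill's first token) so later yields carry only the delta.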

-                 yield TRTLLMWorkerResponse(
-                     request_id=request.id,
-                     prompt_token_ids=ctx_response_obj.prompt_token_ids,
-                     outputs=[asdict(ctx_response_obj.outputs[0])],
-                     finished=ctx_response_obj.finished,
-                 ).model_dump_json(exclude_unset=True)
+             if self._remote_prefill and self._server_type == ServerType.GEN:
+                 ctx_response = await self._get_remote_prefill_response(request)
+                 remote_prefill_response = ctx_response.data()
+                 if remote_prefill_response["finish_reason"] in ("stop", "error"):
+                     yield remote_prefill_response
+                     return
+                 num_output_tokens_so_far = len(remote_prefill_response["token_ids"])

-                 worker_inputs = ctx_response_obj.prompt_token_ids
+                 # Decode the disaggregated params from the remote prefill response
                disaggregated_params = (
                    DisaggregatedTypeConverter.to_llm_disaggregated_params(
                        DisaggregatedParams(
-                             **ctx_response_obj.outputs[0].disaggregated_params
+                             **remote_prefill_response["disaggregated_params"]
                        )
                    )
                )
+
+                 # Send the first token response to the client
+                 first_token_response = remote_prefill_response
+                 first_token_response.pop("disaggregated_params")
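+                 # The disaggregated params were already decoded above for the
+                 # generation step, so drop them from the client-facing payload.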
+                 yield first_token_response
+
                disaggregated_params.request_type = (
                    DisaggRequestType.GENERATION_ONLY.value
                )
@@ -534,29 +551,44 @@ async def generate(self, request: TRTLLMWorkerRequest):
534551 f"Worker inputs: { worker_inputs } , disaggregated params: { disaggregated_params } "
535552 )
536553
-             sampling_params = get_sampling_params(request.sampling_params)
+             sampling_params = get_sampling_params(
+                 request.sampling_options.dict(), self._default_sampling_params
+             )
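+             # max_tokens arrives via stop_conditions rather than sampling_options,
+             # so it is applied to the SamplingParams separately.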
+             max_tokens = request.stop_conditions.max_tokens
+             if max_tokens:
+                 sampling_params.max_tokens = max_tokens
+
            async for response in self._llm_engine.generate_async(
                inputs=worker_inputs,
                sampling_params=sampling_params,
                disaggregated_params=disaggregated_params,
-                 streaming=False
-                 if self._server_type == ServerType.CTX
-                 else request.streaming,
+                 streaming=self._server_type != ServerType.CTX,
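+                 # Prefill (CTX) workers return a single result; only generation
+                 # workers stream tokens back.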
            ):
-                 # Convert the disaggregated params to OAI format so
-                 # it can be sent over the network.
-                 response.outputs[
-                     0
-                 ].disaggregated_params = DisaggregatedTypeConverter.to_oai_disaggregated_params(
-                     response.outputs[0].disaggregated_params
-                 )
+                 if response.finished and self._server_type != ServerType.CTX:
+                     yield {"finish_reason": "stop", "token_ids": []}
+                     break
+
+                 if not response.outputs:
+                     yield {"finish_reason": "error", "token_ids": []}
+                     break

-                 yield TRTLLMWorkerResponse(
-                     request_id=request.id,
-                     prompt_token_ids=response.prompt_token_ids,
-                     outputs=[asdict(response.outputs[0])],
-                     finished=response.finished,
-                 ).model_dump_json(exclude_unset=True)
+                 output = response.outputs[0]
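+                 # token_ids is cumulative across streamed responses, so slice off
+                 # what was already sent and yield only the new tokens.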
+                 next_total_toks = len(output.token_ids)
+                 out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
+                 if output.finish_reason:
+                     out["finish_reason"] = output.finish_reason
+                 if output.stop_reason:
+                     out["stop_reason"] = output.stop_reason
+                 if self._server_type == ServerType.CTX:
+                     # Return the disaggregated params only when operating in prefill mode.
+                     out["disaggregated_params"] = (
+                         DisaggregatedTypeConverter.to_oai_disaggregated_params(
+                             output.disaggregated_params
+                         ).dict()
+                     )
+
+                 yield out
+                 num_output_tokens_so_far = next_total_toks

        except CppExecutorError:
            signal.raise_signal(signal.SIGINT)