1414 limitations under the License.
1515
1616*/
17- use std:: {
18- collections:: { BTreeMap , HashMap } ,
19- sync:: { Arc , Mutex } ,
20- } ;
17+ use std:: { collections:: HashMap , sync:: Arc } ;
2118
22- use dashmap:: DashMap ;
2319use futures:: { StreamExt , TryStreamExt , stream} ;
2420use opentelemetry:: trace:: TraceId ;
2521use tokio:: sync:: mpsc;
@@ -37,8 +33,8 @@ use crate::{
3733 Context , Error ,
3834 common:: { self , text_contents_detections, validate_detectors} ,
3935 types:: {
40- ChatCompletionBatcher , ChatCompletionStream , ChatMessageIterator , ChoiceIndex , Chunk ,
41- DetectionBatchStream , Detections ,
36+ ChatCompletionBatcher , ChatCompletionStream , ChatMessageIterator , Chunk ,
37+ CompletionState , DetectionBatchStream , Detections ,
4238 } ,
4339 } ,
4440} ;
@@ -237,7 +233,7 @@ async fn handle_output_detection(
237233 detectors. into_iter ( ) . partition ( |( detector_id, _) | {
238234 ctx. config . get_chunker_id ( detector_id) . unwrap ( ) == "whole_doc_chunker"
239235 } ) ;
240- let chat_completion_state = Arc :: new ( ChatCompletionState :: new ( ) ) ;
236+ let completion_state = Arc :: new ( CompletionState :: new ( ) ) ;
241237
242238 if !detectors. is_empty ( ) {
243239 // Set up streaming detection pipeline
@@ -279,7 +275,7 @@ async fn handle_output_detection(
279275 tokio:: spawn ( process_chat_completion_stream (
280276 trace_id,
281277 chat_completion_stream,
282- Some ( chat_completion_state . clone ( ) ) ,
278+ Some ( completion_state . clone ( ) ) ,
283279 Some ( input_txs) ,
284280 None ,
285281 ) ) ;
@@ -290,7 +286,7 @@ async fn handle_output_detection(
290286 ) ;
291287 process_detection_batch_stream (
292288 trace_id,
293- chat_completion_state . clone ( ) ,
289+ completion_state . clone ( ) ,
294290 detection_batch_stream,
295291 response_tx. clone ( ) ,
296292 )
@@ -301,7 +297,7 @@ async fn handle_output_detection(
301297 process_chat_completion_stream (
302298 trace_id,
303299 chat_completion_stream,
304- Some ( chat_completion_state . clone ( ) ) ,
300+ Some ( completion_state . clone ( ) ) ,
305301 None ,
306302 Some ( response_tx. clone ( ) ) ,
307303 )
@@ -310,12 +306,12 @@ async fn handle_output_detection(
310306 // NOTE: at this point, the chat completions stream has been fully consumed and chat completion state is final
311307
312308 // If whole doc output detections or usage is requested, a final message is sent with these items
313- if !whole_doc_detectors. is_empty ( ) || chat_completion_state . usage ( ) . is_some ( ) {
309+ if !whole_doc_detectors. is_empty ( ) || completion_state . usage ( ) . is_some ( ) {
314310 let mut chat_completion = ChatCompletionChunk {
315- id : chat_completion_state . id ( ) ,
316- created : chat_completion_state . created ( ) ,
317- model : chat_completion_state . model ( ) ,
318- usage : chat_completion_state . usage ( ) ,
311+ id : completion_state . id ( ) . unwrap ( ) . to_string ( ) ,
312+ created : completion_state . created ( ) . unwrap ( ) ,
313+ model : completion_state . model ( ) . unwrap ( ) . to_string ( ) ,
314+ usage : completion_state . usage ( ) . cloned ( ) ,
319315 ..Default :: default ( )
320316 } ;
321317 if !whole_doc_detectors. is_empty ( ) {
@@ -324,7 +320,7 @@ async fn handle_output_detection(
324320 ctx. clone ( ) ,
325321 task,
326322 whole_doc_detectors,
327- chat_completion_state ,
323+ completion_state ,
328324 )
329325 . await
330326 {
@@ -352,7 +348,7 @@ async fn handle_output_detection(
352348async fn process_chat_completion_stream (
353349 trace_id : TraceId ,
354350 mut chat_completion_stream : ChatCompletionStream ,
355- chat_completion_state : Option < Arc < ChatCompletionState > > ,
351+ completion_state : Option < Arc < CompletionState < ChatCompletionChunk > > > ,
356352 input_txs : Option < HashMap < u32 , mpsc:: Sender < Result < ( usize , String ) , Error > > > > ,
357353 response_tx : Option < mpsc:: Sender < Result < Option < ChatCompletionChunk > , Error > > > ,
358354) {
@@ -372,42 +368,37 @@ async fn process_chat_completion_stream(
372368 return ;
373369 }
374370 }
375- if chat_completion. usage . is_some ( ) {
376- // Set usage state from the usage message
371+ if let Some ( usage) = & chat_completion. usage
372+ && chat_completion. choices . is_empty ( )
373+ {
374+ // Update state: set usage
377375 // NOTE: this message has no choices and is not sent to detection input channel
378- if let Some ( state) = & chat_completion_state {
379- state. metadata . lock ( ) . unwrap ( ) . usage = chat_completion . usage . clone ( ) ;
376+ if let Some ( state) = & completion_state {
377+ state. set_usage ( usage. clone ( ) ) ;
380378 }
381379 } else {
382380 if message_index == 0 {
383- // Set metadata state from the first message
381+ // Update state: set metadata
384382 // NOTE: these values are the same for all chat completion chunks
385- if let Some ( state) = & chat_completion_state {
386- let mut metadata = state. metadata . lock ( ) . unwrap ( ) ;
387- metadata. id = chat_completion. id . clone ( ) ;
388- metadata. created = chat_completion. created ;
389- metadata. model = chat_completion. model . clone ( ) ;
383+ if let Some ( state) = & completion_state {
384+ state. set_metadata (
385+ chat_completion. id . clone ( ) ,
386+ chat_completion. created ,
387+ chat_completion. model . clone ( ) ,
388+ ) ;
390389 }
391390 }
392391 // NOTE: chat completion chunks should contain only 1 choice
393392 if let Some ( choice) = chat_completion. choices . first ( ) {
394393 // Extract choice text
395394 let choice_text = choice. delta . content . clone ( ) . unwrap_or_default ( ) ;
396- // Update state for this choice index
397- if let Some ( state) = & chat_completion_state {
398- match state. chat_completions . entry ( choice. index ) {
399- dashmap:: Entry :: Occupied ( mut entry) => {
400- entry
401- . get_mut ( )
402- . insert ( message_index, chat_completion. clone ( ) ) ;
403- }
404- dashmap:: Entry :: Vacant ( entry) => {
405- entry. insert ( BTreeMap :: from ( [ (
406- message_index,
407- chat_completion. clone ( ) ,
408- ) ] ) ) ;
409- }
410- }
395+ // Update state: insert completion
396+ if let Some ( state) = & completion_state {
397+ state. insert_completion (
398+ choice. index ,
399+ message_index,
400+ chat_completion. clone ( ) ,
401+ ) ;
411402 }
412403 // Send choice text to detection input channel
413404 if let Some ( input_tx) =
@@ -446,11 +437,11 @@ async fn handle_whole_doc_output_detection(
446437 ctx : Arc < Context > ,
447438 task : & ChatCompletionsDetectionTask ,
448439 detectors : HashMap < String , DetectorParams > ,
449- chat_completion_state : Arc < ChatCompletionState > ,
440+ completion_state : Arc < CompletionState < ChatCompletionChunk > > ,
450441) -> Result < ( OpenAiDetections , Vec < OrchestratorWarning > ) , Error > {
451442 // Create vec of choice_index->inputs, where inputs contains the concatenated text for the choice
452- let choice_inputs = chat_completion_state
453- . chat_completions
443+ let choice_inputs = completion_state
444+ . completions
454445 . iter ( )
455446 . map ( |entry| {
456447 let choice_index = * entry. key ( ) ;
@@ -508,16 +499,13 @@ async fn handle_whole_doc_output_detection(
508499
509500/// Builds a response with output detections.
510501fn output_detection_response (
511- chat_completion_state : & Arc < ChatCompletionState > ,
502+ completion_state : & Arc < CompletionState < ChatCompletionChunk > > ,
512503 choice_index : u32 ,
513504 chunk : Chunk ,
514505 detections : Detections ,
515506) -> Result < ChatCompletionChunk , Error > {
516507 // Get chat completions for this choice index
517- let chat_completions = chat_completion_state
518- . chat_completions
519- . get ( & choice_index)
520- . unwrap ( ) ;
508+ let chat_completions = completion_state. completions . get ( & choice_index) . unwrap ( ) ;
521509 // Get range of chat completions for this chunk
522510 let chat_completions = chat_completions
523511 . range ( chunk. input_start_index ..=chunk. input_end_index )
@@ -581,20 +569,16 @@ fn merge_logprobs(chat_completions: &[ChatCompletionChunk]) -> Option<ChatComple
581569/// Consumes a detection batch stream, builds responses, and sends them to a response channel.
582570async fn process_detection_batch_stream (
583571 trace_id : TraceId ,
584- chat_completion_state : Arc < ChatCompletionState > ,
572+ completion_state : Arc < CompletionState < ChatCompletionChunk > > ,
585573 mut detection_batch_stream : DetectionBatchStream ,
586574 response_tx : mpsc:: Sender < Result < Option < ChatCompletionChunk > , Error > > ,
587575) {
588576 while let Some ( result) = detection_batch_stream. next ( ) . await {
589577 match result {
590578 Ok ( ( choice_index, chunk, detections) ) => {
591579 let input_end_index = chunk. input_end_index ;
592- match output_detection_response (
593- & chat_completion_state,
594- choice_index,
595- chunk,
596- detections,
597- ) {
580+ match output_detection_response ( & completion_state, choice_index, chunk, detections)
581+ {
598582 Ok ( chat_completion) => {
599583 // Send chat completion to response channel
600584 debug ! ( %trace_id, %choice_index, ?chat_completion, "sending chat completion chunk to response channel" ) ;
@@ -603,10 +587,8 @@ async fn process_detection_batch_stream(
603587 return ;
604588 }
605589 // If this is the final chat completion chunk with content, send chat completion chunk with finish reason
606- let chat_completions = chat_completion_state
607- . chat_completions
608- . get ( & choice_index)
609- . unwrap ( ) ;
590+ let chat_completions =
591+ completion_state. completions . get ( & choice_index) . unwrap ( ) ;
610592 if chat_completions. keys ( ) . rev ( ) . nth ( 1 ) == Some ( & input_end_index) {
611593 if let Some ( ( _, chat_completion) ) = chat_completions. last_key_value ( ) {
612594 if chat_completion
@@ -645,45 +627,3 @@ async fn process_detection_batch_stream(
645627 }
646628 info ! ( %trace_id, "task completed: detection batch stream closed" ) ;
647629}
648-
649- #[ derive( Debug , Default ) ]
650- struct ChatCompletionMetadata {
651- /// A unique identifier for the chat completion. Each chunk has the same ID.
652- pub id : String ,
653- /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.
654- pub created : i64 ,
655- /// The model to generate the completion.
656- pub model : String ,
657- /// Completion usage statistics.
658- pub usage : Option < Usage > ,
659- }
660-
661- #[ derive( Debug , Default ) ]
662- struct ChatCompletionState {
663- /// Chat completion metadata.
664- pub metadata : Mutex < ChatCompletionMetadata > ,
665- /// A map of chat completion chunks received for each choice.
666- pub chat_completions : DashMap < ChoiceIndex , BTreeMap < usize , ChatCompletionChunk > > ,
667- }
668-
669- impl ChatCompletionState {
670- pub fn new ( ) -> Self {
671- Self :: default ( )
672- }
673-
674- pub fn id ( & self ) -> String {
675- self . metadata . lock ( ) . unwrap ( ) . id . clone ( )
676- }
677-
678- pub fn created ( & self ) -> i64 {
679- self . metadata . lock ( ) . unwrap ( ) . created
680- }
681-
682- pub fn model ( & self ) -> String {
683- self . metadata . lock ( ) . unwrap ( ) . model . clone ( )
684- }
685-
686- pub fn usage ( & self ) -> Option < Usage > {
687- self . metadata . lock ( ) . unwrap ( ) . usage . clone ( )
688- }
689- }
0 commit comments