@@ -615,12 +615,19 @@ async fn process_detection_batch_stream(
615615 mut detection_batch_stream : DetectionBatchStream ,
616616 response_tx : mpsc:: Sender < Result < Option < Completion > , Error > > ,
617617) {
618+ let mut batch_tracker: HashMap < u32 , Vec < ( usize , usize ) > > = HashMap :: new ( ) ;
618619 while let Some ( result) = detection_batch_stream. next ( ) . await {
619620 match result {
620621 Ok ( ( choice_index, chunk, detections) ) => {
622+ let indices = ( chunk. input_start_index , chunk. input_end_index ) ;
621623 match output_detection_response ( & completion_state, choice_index, chunk, detections)
622624 {
623625 Ok ( completion) => {
626+ // Record batch indices to tracker
627+ batch_tracker
628+ . entry ( choice_index)
629+ . and_modify ( |entry| entry. push ( indices) )
630+ . or_insert ( vec ! [ indices] ) ;
624631 // Send completion to response channel
625632 debug ! ( %trace_id, %choice_index, ?completion, "sending completion chunk to response channel" ) ;
626633 if response_tx. send ( Ok ( Some ( completion) ) ) . await . is_err ( ) {
@@ -644,5 +651,26 @@ async fn process_detection_batch_stream(
644651 }
645652 }
646653 }
654+ // Ensure last completion chunk with finish_reason is sent for each choice
655+ for ( choice_index, indices) in batch_tracker {
656+ // Get last completion chunk
657+ let completions = completion_state. completions . get ( & choice_index) . unwrap ( ) ;
658+ let ( last_index, completion) = completions
659+ . last_key_value ( )
660+ . map ( |( index, completion) | ( * index, completion. clone ( ) ) )
661+ . unwrap ( ) ;
662+ let ( _start_index, end_index) = indices. last ( ) . copied ( ) . unwrap ( ) ;
663+ if last_index != end_index {
664+ if last_index != end_index + 1 {
665+ warn ! ( %trace_id, %choice_index, %last_index, %end_index, "unexpected number of completion chunks remaining for choice" ) ;
666+ debug ! ( %trace_id, ?completions) ;
667+ }
668+ debug ! ( %trace_id, %choice_index, ?completion, "sending last completion chunk to response channel" ) ;
669+ if response_tx. send ( Ok ( Some ( completion) ) ) . await . is_err ( ) {
670+ info ! ( %trace_id, "task completed: client disconnected" ) ;
671+ return ;
672+ }
673+ }
674+ }
647675 info ! ( %trace_id, "task completed: detection batch stream closed" ) ;
648676}
0 commit comments