Skip to content

Commit 6b7370e

Browse files
committed
Add logic to track batches and send last completion chunk if not sent with final batch
Signed-off-by: declark1 <[email protected]>
1 parent 8dc485a commit 6b7370e

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

src/orchestrator/handlers/completions_detection/streaming.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,19 @@ async fn process_detection_batch_stream(
615615
mut detection_batch_stream: DetectionBatchStream,
616616
response_tx: mpsc::Sender<Result<Option<Completion>, Error>>,
617617
) {
618+
let mut batch_tracker: HashMap<u32, Vec<(usize, usize)>> = HashMap::new();
618619
while let Some(result) = detection_batch_stream.next().await {
619620
match result {
620621
Ok((choice_index, chunk, detections)) => {
622+
let indices = (chunk.input_start_index, chunk.input_end_index);
621623
match output_detection_response(&completion_state, choice_index, chunk, detections)
622624
{
623625
Ok(completion) => {
626+
// Record batch indices to tracker
627+
batch_tracker
628+
.entry(choice_index)
629+
.and_modify(|entry| entry.push(indices))
630+
.or_insert(vec![indices]);
624631
// Send completion to response channel
625632
debug!(%trace_id, %choice_index, ?completion, "sending completion chunk to response channel");
626633
if response_tx.send(Ok(Some(completion))).await.is_err() {
@@ -644,5 +651,26 @@ async fn process_detection_batch_stream(
644651
}
645652
}
646653
}
654+
// Ensure last completion chunk with finish_reason is sent for each choice
655+
for (choice_index, indices) in batch_tracker {
656+
// Get last completion chunk
657+
let completions = completion_state.completions.get(&choice_index).unwrap();
658+
let (last_index, completion) = completions
659+
.last_key_value()
660+
.map(|(index, completion)| (*index, completion.clone()))
661+
.unwrap();
662+
let (_start_index, end_index) = indices.last().copied().unwrap();
663+
if last_index != end_index {
664+
if last_index != end_index + 1 {
665+
warn!(%trace_id, %choice_index, %last_index, %end_index, "unexpected number of completion chunks remaining for choice");
666+
debug!(%trace_id, ?completions);
667+
}
668+
debug!(%trace_id, %choice_index, ?completion, "sending last completion chunk to response channel");
669+
if response_tx.send(Ok(Some(completion))).await.is_err() {
670+
info!(%trace_id, "task completed: client disconnected");
671+
return;
672+
}
673+
}
674+
}
647675
info!(%trace_id, "task completed: detection batch stream closed");
648676
}

src/orchestrator/types/completion_state.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ where
6666
pub fn usage(&self) -> Option<&Usage> {
6767
self.usage.get()
6868
}
69+
70+
pub fn keys(&self) -> impl Iterator<Item = u32> {
71+
self.completions.iter().map(|entry| *entry.key())
72+
}
6973
}
7074

7175
/// Completion metadata common to all chunks.

0 commit comments

Comments
 (0)