Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/orchestrator/handlers/completions_detection/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,12 +615,19 @@ async fn process_detection_batch_stream(
mut detection_batch_stream: DetectionBatchStream,
response_tx: mpsc::Sender<Result<Option<Completion>, Error>>,
) {
let mut batch_tracker: HashMap<u32, Vec<(usize, usize)>> = HashMap::new();
while let Some(result) = detection_batch_stream.next().await {
match result {
Ok((choice_index, chunk, detections)) => {
let indices = (chunk.input_start_index, chunk.input_end_index);
match output_detection_response(&completion_state, choice_index, chunk, detections)
{
Ok(completion) => {
// Record indices for this batch
batch_tracker
.entry(choice_index)
.and_modify(|entry| entry.push(indices))
.or_insert(vec![indices]);
// Send completion to response channel
debug!(%trace_id, %choice_index, ?completion, "sending completion chunk to response channel");
if response_tx.send(Ok(Some(completion))).await.is_err() {
Expand All @@ -644,5 +651,35 @@ async fn process_detection_batch_stream(
}
}
}
// Ensure the last completion chunk, which includes finish_reason, is sent for each choice.
//
// Edge case: if the last completion chunk has empty choice text, it is not included in the final batch,
// because chunks without choice text are not sent to the detection pipeline.
for (choice_index, indices) in batch_tracker {
// Lookup the last completion chunk received
let completions = completion_state.completions.get(&choice_index).unwrap();
let (last_index, completion) = completions
.last_key_value()
.map(|(index, completion)| (*index, completion))
.unwrap();
// Get the index of last completion chunk included in the last batch
let (_start_index, end_index) = indices.last().copied().unwrap();
if last_index != end_index {
// The last batch didn't include the last completion chunk, send it to the response channel
if last_index != end_index + 1 {
warn!(%trace_id, %choice_index, %last_index, %end_index, "unexpected number of completion chunks remaining for choice");
debug!(%trace_id, ?completions);
}
debug!(%trace_id, %choice_index, ?completion, "sending last completion chunk to response channel");
if response_tx
.send(Ok(Some(completion.clone())))
.await
.is_err()
{
info!(%trace_id, "task completed: client disconnected");
return;
}
}
}
info!(%trace_id, "task completed: detection batch stream closed");
}