Commit 9a3fa0c

Add Chat Completions Streaming support for continuous_usage_stats (foundation-model-stack#456)

* Move usage to ChatCompletionState, update condition to set final usage state, add output_detectors_with_continuous_usage_stats test

  Signed-off-by: declark1 <[email protected]>

* Add CompletionState, drop ChatCompletionState and update existing usage

  Signed-off-by: declark1 <[email protected]>

---------

Signed-off-by: declark1 <[email protected]>

1 parent: 3164f4a

File tree: 4 files changed (+765, -112 lines)

src/orchestrator/handlers/chat_completions_detection/streaming.rs

Lines changed: 44 additions & 104 deletions
@@ -14,12 +14,8 @@
 limitations under the License.

 */
-use std::{
-    collections::{BTreeMap, HashMap},
-    sync::{Arc, Mutex},
-};
+use std::{collections::HashMap, sync::Arc};

-use dashmap::DashMap;
 use futures::{StreamExt, TryStreamExt, stream};
 use opentelemetry::trace::TraceId;
 use tokio::sync::mpsc;
@@ -37,8 +33,8 @@ use crate::{
         Context, Error,
         common::{self, text_contents_detections, validate_detectors},
         types::{
-            ChatCompletionBatcher, ChatCompletionStream, ChatMessageIterator, ChoiceIndex, Chunk,
-            DetectionBatchStream, Detections,
+            ChatCompletionBatcher, ChatCompletionStream, ChatMessageIterator, Chunk,
+            CompletionState, DetectionBatchStream, Detections,
         },
     },
 };
@@ -237,7 +233,7 @@ async fn handle_output_detection(
         detectors.into_iter().partition(|(detector_id, _)| {
             ctx.config.get_chunker_id(detector_id).unwrap() == "whole_doc_chunker"
         });
-    let chat_completion_state = Arc::new(ChatCompletionState::new());
+    let completion_state = Arc::new(CompletionState::new());

     if !detectors.is_empty() {
         // Set up streaming detection pipeline
@@ -279,7 +275,7 @@ async fn handle_output_detection(
         tokio::spawn(process_chat_completion_stream(
             trace_id,
             chat_completion_stream,
-            Some(chat_completion_state.clone()),
+            Some(completion_state.clone()),
             Some(input_txs),
             None,
         ));
@@ -290,7 +286,7 @@ async fn handle_output_detection(
         );
         process_detection_batch_stream(
             trace_id,
-            chat_completion_state.clone(),
+            completion_state.clone(),
             detection_batch_stream,
             response_tx.clone(),
         )
@@ -301,7 +297,7 @@ async fn handle_output_detection(
         process_chat_completion_stream(
             trace_id,
             chat_completion_stream,
-            Some(chat_completion_state.clone()),
+            Some(completion_state.clone()),
             None,
             Some(response_tx.clone()),
         )
@@ -310,12 +306,12 @@ async fn handle_output_detection(
     // NOTE: at this point, the chat completions stream has been fully consumed and chat completion state is final

     // If whole doc output detections or usage is requested, a final message is sent with these items
-    if !whole_doc_detectors.is_empty() || chat_completion_state.usage().is_some() {
+    if !whole_doc_detectors.is_empty() || completion_state.usage().is_some() {
         let mut chat_completion = ChatCompletionChunk {
-            id: chat_completion_state.id(),
-            created: chat_completion_state.created(),
-            model: chat_completion_state.model(),
-            usage: chat_completion_state.usage(),
+            id: completion_state.id().unwrap().to_string(),
+            created: completion_state.created().unwrap(),
+            model: completion_state.model().unwrap().to_string(),
+            usage: completion_state.usage().cloned(),
             ..Default::default()
         };
         if !whole_doc_detectors.is_empty() {
@@ -324,7 +320,7 @@ async fn handle_output_detection(
                 ctx.clone(),
                 task,
                 whole_doc_detectors,
-                chat_completion_state,
+                completion_state,
             )
             .await
             {
@@ -352,7 +348,7 @@ async fn handle_output_detection(
 async fn process_chat_completion_stream(
     trace_id: TraceId,
     mut chat_completion_stream: ChatCompletionStream,
-    chat_completion_state: Option<Arc<ChatCompletionState>>,
+    completion_state: Option<Arc<CompletionState<ChatCompletionChunk>>>,
     input_txs: Option<HashMap<u32, mpsc::Sender<Result<(usize, String), Error>>>>,
     response_tx: Option<mpsc::Sender<Result<Option<ChatCompletionChunk>, Error>>>,
 ) {
@@ -372,42 +368,37 @@ async fn process_chat_completion_stream(
                         return;
                     }
                 }
-                if chat_completion.usage.is_some() {
-                    // Set usage state from the usage message
+                if let Some(usage) = &chat_completion.usage
+                    && chat_completion.choices.is_empty()
+                {
+                    // Update state: set usage
                     // NOTE: this message has no choices and is not sent to detection input channel
-                    if let Some(state) = &chat_completion_state {
-                        state.metadata.lock().unwrap().usage = chat_completion.usage.clone();
+                    if let Some(state) = &completion_state {
+                        state.set_usage(usage.clone());
                     }
                 } else {
                     if message_index == 0 {
-                        // Set metadata state from the first message
+                        // Update state: set metadata
                        // NOTE: these values are the same for all chat completion chunks
-                        if let Some(state) = &chat_completion_state {
-                            let mut metadata = state.metadata.lock().unwrap();
-                            metadata.id = chat_completion.id.clone();
-                            metadata.created = chat_completion.created;
-                            metadata.model = chat_completion.model.clone();
+                        if let Some(state) = &completion_state {
+                            state.set_metadata(
+                                chat_completion.id.clone(),
+                                chat_completion.created,
+                                chat_completion.model.clone(),
+                            );
                         }
                     }
                     // NOTE: chat completion chunks should contain only 1 choice
                     if let Some(choice) = chat_completion.choices.first() {
                         // Extract choice text
                         let choice_text = choice.delta.content.clone().unwrap_or_default();
-                        // Update state for this choice index
-                        if let Some(state) = &chat_completion_state {
-                            match state.chat_completions.entry(choice.index) {
-                                dashmap::Entry::Occupied(mut entry) => {
-                                    entry
-                                        .get_mut()
-                                        .insert(message_index, chat_completion.clone());
-                                }
-                                dashmap::Entry::Vacant(entry) => {
-                                    entry.insert(BTreeMap::from([(
-                                        message_index,
-                                        chat_completion.clone(),
-                                    )]));
-                                }
-                            }
+                        // Update state: insert completion
+                        if let Some(state) = &completion_state {
+                            state.insert_completion(
+                                choice.index,
+                                message_index,
+                                chat_completion.clone(),
+                            );
                         }
                         // Send choice text to detection input channel
                        if let Some(input_tx) =
@@ -446,11 +437,11 @@ async fn handle_whole_doc_output_detection(
     ctx: Arc<Context>,
     task: &ChatCompletionsDetectionTask,
     detectors: HashMap<String, DetectorParams>,
-    chat_completion_state: Arc<ChatCompletionState>,
+    completion_state: Arc<CompletionState<ChatCompletionChunk>>,
 ) -> Result<(OpenAiDetections, Vec<OrchestratorWarning>), Error> {
     // Create vec of choice_index->inputs, where inputs contains the concatenated text for the choice
-    let choice_inputs = chat_completion_state
-        .chat_completions
+    let choice_inputs = completion_state
+        .completions
         .iter()
         .map(|entry| {
             let choice_index = *entry.key();
@@ -508,16 +499,13 @@ async fn handle_whole_doc_output_detection(

 /// Builds a response with output detections.
 fn output_detection_response(
-    chat_completion_state: &Arc<ChatCompletionState>,
+    completion_state: &Arc<CompletionState<ChatCompletionChunk>>,
     choice_index: u32,
     chunk: Chunk,
     detections: Detections,
 ) -> Result<ChatCompletionChunk, Error> {
     // Get chat completions for this choice index
-    let chat_completions = chat_completion_state
-        .chat_completions
-        .get(&choice_index)
-        .unwrap();
+    let chat_completions = completion_state.completions.get(&choice_index).unwrap();
     // Get range of chat completions for this chunk
     let chat_completions = chat_completions
         .range(chunk.input_start_index..=chunk.input_end_index)
@@ -581,20 +569,16 @@ fn merge_logprobs(chat_completions: &[ChatCompletionChunk]) -> Option<ChatComple
 /// Consumes a detection batch stream, builds responses, and sends them to a response channel.
 async fn process_detection_batch_stream(
     trace_id: TraceId,
-    chat_completion_state: Arc<ChatCompletionState>,
+    completion_state: Arc<CompletionState<ChatCompletionChunk>>,
     mut detection_batch_stream: DetectionBatchStream,
     response_tx: mpsc::Sender<Result<Option<ChatCompletionChunk>, Error>>,
 ) {
     while let Some(result) = detection_batch_stream.next().await {
         match result {
             Ok((choice_index, chunk, detections)) => {
                 let input_end_index = chunk.input_end_index;
-                match output_detection_response(
-                    &chat_completion_state,
-                    choice_index,
-                    chunk,
-                    detections,
-                ) {
+                match output_detection_response(&completion_state, choice_index, chunk, detections)
+                {
                     Ok(chat_completion) => {
                         // Send chat completion to response channel
                         debug!(%trace_id, %choice_index, ?chat_completion, "sending chat completion chunk to response channel");
@@ -603,10 +587,8 @@ async fn process_detection_batch_stream(
                             return;
                         }
                         // If this is the final chat completion chunk with content, send chat completion chunk with finish reason
-                        let chat_completions = chat_completion_state
-                            .chat_completions
-                            .get(&choice_index)
-                            .unwrap();
+                        let chat_completions =
+                            completion_state.completions.get(&choice_index).unwrap();
                         if chat_completions.keys().rev().nth(1) == Some(&input_end_index) {
                             if let Some((_, chat_completion)) = chat_completions.last_key_value() {
                                 if chat_completion
@@ -645,45 +627,3 @@ async fn process_detection_batch_stream(
     }
     info!(%trace_id, "task completed: detection batch stream closed");
 }
-
-#[derive(Debug, Default)]
-struct ChatCompletionMetadata {
-    /// A unique identifier for the chat completion. Each chunk has the same ID.
-    pub id: String,
-    /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.
-    pub created: i64,
-    /// The model to generate the completion.
-    pub model: String,
-    /// Completion usage statistics.
-    pub usage: Option<Usage>,
-}
-
-#[derive(Debug, Default)]
-struct ChatCompletionState {
-    /// Chat completion metadata.
-    pub metadata: Mutex<ChatCompletionMetadata>,
-    /// A map of chat completion chunks received for each choice.
-    pub chat_completions: DashMap<ChoiceIndex, BTreeMap<usize, ChatCompletionChunk>>,
-}
-
-impl ChatCompletionState {
-    pub fn new() -> Self {
-        Self::default()
-    }
-
-    pub fn id(&self) -> String {
-        self.metadata.lock().unwrap().id.clone()
-    }
-
-    pub fn created(&self) -> i64 {
-        self.metadata.lock().unwrap().created
-    }
-
-    pub fn model(&self) -> String {
-        self.metadata.lock().unwrap().model.clone()
-    }
-
-    pub fn usage(&self) -> Option<Usage> {
-        self.metadata.lock().unwrap().usage.clone()
-    }
-}
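
Note on the change above: process_chat_completion_stream now treats a chunk as the final usage message only when it carries usage and its choices list is empty, and only that chunk sets the completion state's usage. Presumably this is what keeps the handler correct when continuous_usage_stats is enabled, since content-bearing chunks can then also carry a running usage value. A minimal, self-contained sketch of that predicate follows; the Usage and Chunk types here are simplified stand-ins introduced for illustration, not the crate's clients::openai::Usage or ChatCompletionChunk.

// Simplified stand-in types, introduced only for this illustration; the real
// handler operates on the crate's ChatCompletionChunk and clients::openai::Usage.
#[allow(dead_code)]
struct Usage {
    total_tokens: u32,
}

struct Chunk {
    choices: Vec<String>, // stand-in for the chunk's choice deltas
    usage: Option<Usage>,
}

/// Mirrors the updated condition: only the trailing usage-only message
/// (usage present, choices empty) finalizes usage state.
fn is_final_usage_chunk(chunk: &Chunk) -> bool {
    chunk.usage.is_some() && chunk.choices.is_empty()
}

fn main() {
    // With continuous usage stats, a content chunk may also carry a running usage total...
    let content_chunk = Chunk {
        choices: vec!["Hello".to_string()],
        usage: Some(Usage { total_tokens: 6 }),
    };
    // ...while the stream's final usage message carries usage but no choices.
    let final_usage_chunk = Chunk {
        choices: vec![],
        usage: Some(Usage { total_tokens: 12 }),
    };
    assert!(!is_final_usage_chunk(&content_chunk));
    assert!(is_final_usage_chunk(&final_usage_chunk));
    println!("usage condition behaves as expected");
}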

src/orchestrator/types.rs

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ pub mod detection_batcher;
 pub use detection_batcher::*;
 pub mod detection_batch_stream;
 pub use detection_batch_stream::*;
+pub mod completion_state;
+pub use completion_state::*;

 use super::Error;
 use crate::{

src/orchestrator/types/completion_state.rs (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+use std::{collections::BTreeMap, sync::OnceLock};
+
+use dashmap::DashMap;
+
+use super::ChoiceIndex;
+use crate::clients::openai::Usage;
+
+/// Completion state for a streaming completions task.
+#[derive(Debug, Default)]
+pub struct CompletionState<T> {
+    /// Completion metadata.
+    pub metadata: OnceLock<CompletionMetadata>,
+    /// Completion chunks received for each choice.
+    pub completions: DashMap<ChoiceIndex, BTreeMap<usize, T>>,
+    /// Completion usage statistics.
+    pub usage: OnceLock<Usage>,
+}
+
+impl<T> CompletionState<T>
+where
+    T: Default,
+{
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets metadata.
+    pub fn set_metadata(&self, id: String, created: i64, model: String) {
+        let _ = self.metadata.set(CompletionMetadata { id, created, model });
+    }
+
+    /// Sets usage.
+    pub fn set_usage(&self, usage: Usage) {
+        let _ = self.usage.set(usage);
+    }
+
+    /// Inserts a completion.
+    pub fn insert_completion(
+        &self,
+        choice_index: ChoiceIndex,
+        message_index: usize,
+        completion: T,
+    ) {
+        match self.completions.entry(choice_index) {
+            dashmap::Entry::Occupied(mut entry) => {
+                entry.get_mut().insert(message_index, completion);
+            }
+            dashmap::Entry::Vacant(entry) => {
+                entry.insert(BTreeMap::from([(message_index, completion)]));
+            }
+        }
+    }
+
+    pub fn id(&self) -> Option<&str> {
+        self.metadata.get().map(|v| v.id.as_ref())
+    }
+
+    pub fn created(&self) -> Option<i64> {
+        self.metadata.get().map(|v| v.created)
+    }
+
+    pub fn model(&self) -> Option<&str> {
+        self.metadata.get().map(|v| v.model.as_ref())
+    }
+
+    pub fn usage(&self) -> Option<&Usage> {
+        self.usage.get()
+    }
+}
+
+/// Completion metadata common to all chunks.
+#[derive(Debug, Default)]
+pub struct CompletionMetadata {
+    /// A unique identifier for the completion.
+    pub id: String,
+    /// The Unix timestamp (in seconds) of when the completion was created.
+    pub created: i64,
+    /// The model to generate the completion.
+    pub model: String,
+}
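
Usage sketch for the new CompletionState, mirroring how the streaming handler above drives it: metadata is recorded once from the first content chunk, each chunk is inserted under its (choice index, message index) pair, and usage is set once from the trailing usage-only message, after which readers pull the final values. This sketch assumes the crate context of this repository (the crate::orchestrator::types::CompletionState and crate::clients::openai paths shown in the diffs), that ChoiceIndex is the u32 choice index used elsewhere in the handler, and that Usage implements Default; it is illustrative, not code from the commit.

use std::sync::Arc;

use crate::clients::openai::{ChatCompletionChunk, Usage};
use crate::orchestrator::types::CompletionState;

// Illustrative only: in the orchestrator this state is shared via Arc between
// process_chat_completion_stream (writer) and process_detection_batch_stream (reader).
fn completion_state_sketch() {
    let state = Arc::new(CompletionState::<ChatCompletionChunk>::new());

    // First content chunk: record the metadata shared by every chunk (set once via OnceLock).
    state.set_metadata(
        "chatcmpl-123".to_string(),
        1_700_000_000,
        "example-model".to_string(),
    );

    // Each content chunk is stored under its (choice_index, message_index).
    state.insert_completion(0, 0, ChatCompletionChunk::default());
    state.insert_completion(0, 1, ChatCompletionChunk::default());

    // The trailing usage-only message sets the final usage exactly once.
    // Usage::default() is an assumption here; the handler clones the chunk's actual usage.
    state.set_usage(Usage::default());

    // Once the stream is drained, the final response can be assembled from the accessors.
    assert_eq!(state.id(), Some("chatcmpl-123"));
    assert_eq!(state.created(), Some(1_700_000_000));
    assert!(state.usage().is_some());
    assert_eq!(state.completions.get(&0).unwrap().len(), 2);
}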
