Skip to content

Commit 700f90e

Browse files
mdevinodeclark1
andauthored
Stream content endpoint (foundation-model-stack#272)
* wip: /stream-content endpoint Signed-off-by: Mateus Devino <[email protected]> * wip: initial streaming and request parsing Signed-off-by: Mateus Devino <[email protected]> * wip: refactor stream error handling Signed-off-by: Mateus Devino <[email protected]> * Improve error handling Signed-off-by: Mateus Devino <[email protected]> * tweaks Signed-off-by: declark1 <[email protected]> * Move handle_streaming_content_detection handler to separate file Signed-off-by: declark1 <[email protected]> * Rename response 'detections' field Signed-off-by: Mateus Devino <[email protected]> * Fix error handling when 'content' is missing on first request Signed-off-by: Mateus Devino <[email protected]> * Fix generation streaming endpoint link Signed-off-by: Mateus Devino <[email protected]> * Fix error handling - second stream message onwards Signed-off-by: Mateus Devino <[email protected]> * Return an actual ndjson Signed-off-by: Mateus Devino <[email protected]> * stream client binary (for testing) Signed-off-by: Mateus Devino <[email protected]> * nd_json for error Signed-off-by: Mateus Devino <[email protected]> * Refactor detector extraction from first frame Signed-off-by: Mateus Devino <[email protected]> * wip: adapt existing code to new endpoint Signed-off-by: Mateus Devino <[email protected]> * Adapt aggregator logic Signed-off-by: Mateus Devino <[email protected]> * Forward results to output stream Signed-off-by: Mateus Devino <[email protected]> * Fix error handling for second frame onward Signed-off-by: Mateus Devino <[email protected]> * Remove unused code Signed-off-by: Mateus Devino <[email protected]> * Rename generation variables Signed-off-by: Mateus Devino <[email protected]> * Re-add request validation Signed-off-by: Mateus Devino <[email protected]> * Remove unused code Signed-off-by: Mateus Devino <[email protected]> * Re-add missing aggregator test Signed-off-by: Mateus Devino <[email protected]> * Remove unneeded derives Signed-off-by: Mateus Devino <[email protected]> * Add copyright note to streaming_content_detection/aggregator.rs Signed-off-by: Mateus Devino <[email protected]> * Nest imports on streaming_content_detection.rs Signed-off-by: Mateus Devino <[email protected]> * Rename request to frame Signed-off-by: Mateus Devino <[email protected]> * Remove unused variable Signed-off-by: Mateus Devino <[email protected]> * Replace errors Debug with Display Signed-off-by: Mateus Devino <[email protected]> * Reuse streaming/aggregator.rs types Signed-off-by: Mateus Devino <[email protected]> * Simplify aggregator logic Signed-off-by: Mateus Devino <[email protected]> * Refactor nd_json conversion logic Signed-off-by: Mateus Devino <[email protected]> * Rename streaming_output_detection_task() Signed-off-by: Mateus Devino <[email protected]> * Drop validation for extra detectors on second frame onwards Signed-off-by: Mateus Devino <[email protected]> * Fix import formatting Signed-off-by: Mateus Devino <[email protected]> * Add comments for a possible refactor Signed-off-by: Mateus Devino <[email protected]> * Replace Tracker Default implementation with a derive macro Signed-off-by: Mateus Devino <[email protected]> --------- Signed-off-by: Mateus Devino <[email protected]> Signed-off-by: declark1 <[email protected]> Co-authored-by: declark1 <[email protected]>
1 parent 0da639d commit 700f90e

File tree

10 files changed

+1055
-11
lines changed

10 files changed

+1055
-11
lines changed

src/models.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,34 @@ pub struct EvidenceObj {
11051105
pub evidence: Option<Vec<Evidence>>,
11061106
}
11071107

1108+
/// Stream content detection stream request
1109+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1110+
#[cfg_attr(test, derive(Default))]
1111+
pub struct StreamingContentDetectionRequest {
1112+
pub detectors: Option<HashMap<String, DetectorParams>>,
1113+
pub content: String,
1114+
}
1115+
1116+
impl StreamingContentDetectionRequest {
1117+
/// validates stream messages
1118+
pub fn validate(&self) -> Result<(), ValidationError> {
1119+
if self.content.is_empty() {
1120+
return Err(ValidationError::Invalid(
1121+
"`content` cannot be empty".to_string(),
1122+
));
1123+
}
1124+
Ok(())
1125+
}
1126+
}
1127+
1128+
/// Stream content detection response
1129+
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
1130+
pub struct StreamingContentDetectionResponse {
1131+
pub detections: Vec<ContentAnalysisResponse>,
1132+
pub processed_index: u32,
1133+
pub start_index: u32,
1134+
}
1135+
11081136
#[cfg(test)]
11091137
mod tests {
11101138
use super::*;

src/orchestrator.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717

1818
pub mod errors;
1919
pub use errors::Error;
20+
use futures::Stream;
2021
pub mod chat_completions_detection;
2122
pub mod streaming;
23+
pub mod streaming_content_detection;
2224
pub mod unary;
2325

24-
use std::{collections::HashMap, sync::Arc};
26+
use std::{collections::HashMap, pin::Pin, sync::Arc};
2527

2628
use axum::http::header::HeaderMap;
2729
use opentelemetry::trace::TraceId;
@@ -44,7 +46,8 @@ use crate::{
4446
models::{
4547
ChatDetectionHttpRequest, ContextDocsHttpRequest, DetectionOnGeneratedHttpRequest,
4648
DetectorParams, GenerationWithDetectionHttpRequest, GuardrailsConfig,
47-
GuardrailsHttpRequest, GuardrailsTextGenerationParameters, TextContentDetectionHttpRequest,
49+
GuardrailsHttpRequest, GuardrailsTextGenerationParameters,
50+
StreamingContentDetectionRequest, TextContentDetectionHttpRequest,
4851
},
4952
};
5053

@@ -490,6 +493,31 @@ impl ChatCompletionsDetectionTask {
490493
}
491494
}
492495

496+
pub struct StreamingContentDetectionTask {
497+
pub trace_id: TraceId,
498+
pub headers: HeaderMap,
499+
pub detectors: HashMap<String, DetectorParams>,
500+
pub input_stream:
501+
Pin<Box<dyn Stream<Item = Result<StreamingContentDetectionRequest, Error>> + Send>>,
502+
}
503+
504+
impl StreamingContentDetectionTask {
505+
pub fn new(
506+
trace_id: TraceId,
507+
headers: HeaderMap,
508+
input_stream: Pin<
509+
Box<dyn Stream<Item = Result<StreamingContentDetectionRequest, Error>> + Send>,
510+
>,
511+
) -> Self {
512+
Self {
513+
trace_id,
514+
headers,
515+
detectors: HashMap::default(),
516+
input_stream,
517+
}
518+
}
519+
}
520+
493521
#[cfg(test)]
494522
mod tests {
495523
use super::*;

src/orchestrator/errors.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
limitations under the License.
1515
1616
*/
17-
18-
use crate::clients;
17+
use crate::{clients, models::ValidationError};
1918

2019
/// Orchestrator errors.
2120
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
@@ -34,10 +33,14 @@ pub enum Error {
3433
ChatGenerateRequestFailed { id: String, error: clients::Error },
3534
#[error("tokenize request failed for `{id}`: {error}")]
3635
TokenizeRequestFailed { id: String, error: clients::Error },
36+
#[error("validation error: {0}")]
37+
Validation(String),
3738
#[error("{0}")]
3839
Other(String),
3940
#[error("cancelled")]
4041
Cancelled,
42+
#[error("json deserialization error: {0}")]
43+
JsonError(String),
4144
}
4245

4346
impl From<tokio::task::JoinError> for Error {
@@ -49,3 +52,15 @@ impl From<tokio::task::JoinError> for Error {
4952
}
5053
}
5154
}
55+
56+
impl From<serde_json::Error> for Error {
57+
fn from(value: serde_json::Error) -> Self {
58+
Self::JsonError(value.to_string())
59+
}
60+
}
61+
62+
impl From<ValidationError> for Error {
63+
fn from(value: ValidationError) -> Self {
64+
Self::Validation(value.to_string())
65+
}
66+
}

src/orchestrator/streaming.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
1616
*/
1717

18-
mod aggregator;
18+
pub mod aggregator;
1919

2020
use std::{collections::HashMap, pin::Pin, sync::Arc, time::Duration};
2121

src/orchestrator/streaming/aggregator.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ impl GenerationActorHandle {
370370
}
371371

372372
#[derive(Debug, Clone)]
373-
struct TrackerEntry {
373+
pub struct TrackerEntry {
374374
pub chunk: Chunk,
375375
pub detections: Vec<Detections>,
376376
}
@@ -384,8 +384,8 @@ impl TrackerEntry {
384384
}
385385
}
386386

387-
#[derive(Debug, Clone)]
388-
struct Tracker {
387+
#[derive(Debug, Clone, Default)]
388+
pub struct Tracker {
389389
state: BTreeMap<Span, TrackerEntry>,
390390
}
391391

0 commit comments

Comments
 (0)