Commit 99f5d8f

gkumbhat, evaline-ju, and declark1 authored
Add chat content detector (input detection only) (foundation-model-stack#276)
* ✨ Add new datatype for internal conversion of chat content objects
* ♻️ Refactor some implementations and add request conversion
* ✨ Add detectors config in openai request
* 🧑‍💻 Add detection response and warning in openai chat completion response
* 🐛 Push detector processing module missed earlier
* ♻️ Make certain data models available public
* ♻️ Refactor chunking mechanism
* 🚧 Iterate over chunking for chat completion content detection
* 🐛 Fixed compiler error for iterating over chat messages
* 🚧 Add content detection function with datatype transformations
* 🚧 Add input detection handling to main function call
* 🐛 Fix validation error propagation for filtering messages
* 🐛 Fix error handling for chunker
* ⚰️ Remove print lines and deadcode
* 🐛 Remove orchestrator detection result tagged response
* 🐛 Fix detections and results typo
* 🔧 Make detector optional in request and add some error handling todos
* 🎨 Fix formatting and clippy warnings
* ♻️ Refactor part of detection logic for chat completion to allow proper message indexing
* :white_check_marks: Add unit test for message filtering
* 🚚 Rename filter_chat_messages function to preprocess
* 🔧 Fix dummy user id generation and add system to allowed message role
* 🔧 Fix according to review suggestions
* Apply suggestions from code review
  Co-authored-by: Evaline Ju <[email protected]>
  Co-authored-by: Dan Clark <[email protected]>
  Signed-off-by: Gaurav Kumbhat <[email protected]>
* 🎨🚚 Rename and modify optional parameter as per review suggestions
* ⚡ Remove need to clone and perform inplace sorting
* ⚡✨ Address review suggestions and fixes
* Replace warning type from string to predefined type
* Add input detection error to warning
* Replace detection result sorting to previous method and remove results cloning
* Remove extra cloning for filter_chat_message function
* Detector message traversal optimizations
* Add empty message validation
* Apply suggestions from code review
  Co-authored-by: Dan Clark <[email protected]>
  Signed-off-by: Gaurav Kumbhat <[email protected]>
* 🎨 replace match statement with let-else to rustify
* Apply suggestions from code review
  Co-authored-by: Dan Clark <[email protected]>
  Signed-off-by: Gaurav Kumbhat <[email protected]>
* :white_check_marks: Fix unit test and rename filter message function
* ➕ Add uuid in cargo file
* ♻️ Refactor request object handling to avoid cloning

---------

Signed-off-by: gkumbhat <[email protected]>
Signed-off-by: Gaurav Kumbhat <[email protected]>
Co-authored-by: Evaline Ju <[email protected]>
Co-authored-by: Dan Clark <[email protected]>
1 parent 1dd51fb commit 99f5d8f

10 files changed, +701 -25 lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ tracing = "0.1.41"
 tracing-opentelemetry = "0.28.0"
 tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] }
 url = "2.5.4"
+uuid = { version = "1.10.0", features = ["v4"] }

 [build-dependencies]
 tonic-build = "0.12.3"

src/clients/openai.rs

Lines changed: 77 additions & 9 deletions
@@ -27,7 +27,11 @@ use tokio::sync::mpsc;
 use tracing::{info, instrument};

 use super::{create_http_client, http::HttpClientExt, Client, Error, HttpClient};
-use crate::{config::ServiceConfig, health::HealthCheckResult};
+use crate::{
+    config::ServiceConfig,
+    health::HealthCheckResult,
+    models::{DetectorParams, GuardrailDetection, InputWarningReason},
+};

 const DEFAULT_PORT: u16 = 8080;

@@ -157,13 +161,13 @@ impl HttpClientExt for OpenAiClient {

 #[derive(Debug)]
 pub enum ChatCompletionsResponse {
-    Unary(ChatCompletion),
+    Unary(Box<ChatCompletion>),
     Streaming(mpsc::Receiver<Result<Option<ChatCompletionChunk>, Error>>),
 }

 impl From<ChatCompletion> for ChatCompletionsResponse {
     fn from(value: ChatCompletion) -> Self {
-        Self::Unary(value)
+        Self::Unary(Box::new(value))
     }
 }
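
Note on the `Unary(Box<ChatCompletion>)` change: an enum is as large as its largest variant, so storing a big struct inline inflates every value of the enum, including `Streaming` ones. Boxing the payload is the fix that Clippy's `large_enum_variant` lint suggests; the commit message mentions Clippy fixes but does not state that this lint was the trigger, so treat that link as an assumption. A minimal sketch of the size effect, using a stand-in type `BigPayload` (hypothetical, not the real `ChatCompletion`):

```rust
use std::mem::size_of;

/// Stand-in for a large response struct (hypothetical; not the real `ChatCompletion`).
#[allow(dead_code)]
struct BigPayload {
    bytes: [u8; 512],
}

/// Payload stored inline: every value of the enum reserves room for `BigPayload`.
#[allow(dead_code)]
enum Inline {
    Unary(BigPayload),
    Streaming(u32),
}

/// Payload stored behind a pointer: the enum only holds the `Box`.
#[allow(dead_code)]
enum Boxed {
    Unary(Box<BigPayload>),
    Streaming(u32),
}

fn main() {
    println!("inline: {} bytes", size_of::<Inline>());
    println!("boxed:  {} bytes", size_of::<Boxed>());
    // The boxed enum is smaller because only a pointer is stored in the variant.
    assert!(size_of::<Boxed>() < size_of::<Inline>());
}
```

The cost is one heap allocation per unary response, which is usually negligible next to the network call that produced it.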

@@ -275,6 +279,23 @@ pub struct ChatCompletionsRequest {
     pub skip_special_tokens: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub spaces_between_special_tokens: Option<bool>,
+
+    // Detectors
+    // Note: We are making it optional, since this structure also gets used to
+    // form request for chat completions. And downstream server, might choose to
+    // reject extra parameters.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detectors: Option<DetectorConfig>,
+}
+
+/// Structure to contain parameters for detectors.
+#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct DetectorConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input: Option<HashMap<String, DetectorParams>>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output: Option<HashMap<String, DetectorParams>>,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]
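
Note on the optional `detectors` field: both the field itself and its `input`/`output` maps are `Option`s, so a caller has three cases to handle: no `detectors` at all, `detectors` present without `input`, and `input` present with one or more detector IDs. The sketch below walks through those cases. It is an illustration only: the serde derives are omitted, `DetectorParams` is replaced by a placeholder alias, and `input_detector_ids` and the `"pii"` detector ID are hypothetical names that do not appear in the commit.

```rust
use std::collections::HashMap;

/// Placeholder for the crate's `DetectorParams` (assumed here to be a plain string map).
type DetectorParams = HashMap<String, String>;

/// Trimmed copy of the struct from the diff, without the serde derives.
#[allow(dead_code)]
#[derive(Default, Debug, Clone, PartialEq)]
struct DetectorConfig {
    input: Option<HashMap<String, DetectorParams>>,
    output: Option<HashMap<String, DetectorParams>>,
}

/// Hypothetical helper: IDs of the configured input detectors, sorted for stable output.
fn input_detector_ids(detectors: Option<&DetectorConfig>) -> Vec<String> {
    let mut ids: Vec<String> = detectors
        .and_then(|d| d.input.as_ref())
        .map(|m| m.keys().cloned().collect())
        .unwrap_or_default();
    ids.sort();
    ids
}

fn main() {
    // Case 1: the request carries no `detectors` field.
    assert!(input_detector_ids(None).is_empty());

    // Case 2: `detectors` is present but `input` is not set.
    assert!(input_detector_ids(Some(&DetectorConfig::default())).is_empty());

    // Case 3: one input detector is configured ("pii" is a made-up ID).
    let mut input = HashMap::new();
    input.insert("pii".to_string(), DetectorParams::new());
    let config = DetectorConfig { input: Some(input), output: None };
    assert_eq!(input_detector_ids(Some(&config)), vec!["pii".to_string()]);
}
```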
@@ -386,7 +407,7 @@ pub struct Message {
     pub tool_call_id: Option<String>,
 }

-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[serde(untagged)]
 pub enum Content {
     /// The text contents of the message.
@@ -430,7 +451,7 @@ impl From<Vec<String>> for Content {
     }
 }

-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
 pub enum ContentType {
     #[serde(rename = "text")]
     #[default]
@@ -439,7 +460,7 @@ pub enum ContentType {
     ImageUrl,
 }

-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ContentPart {
     /// The type of the content part.
     #[serde(rename = "type")]
@@ -455,7 +476,7 @@ pub struct ContentPart {
     pub refusal: Option<String>,
 }

-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct ImageUrl {
     /// Either a URL of the image or the base64 encoded image data.
     pub url: String,
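
Note on the four `PartialEq` derives: `#[derive(PartialEq)]` on `Content` only compiles if every type it contains also implements `PartialEq`, which is presumably why `ContentType`, `ContentPart`, and `ImageUrl` gain the derive in the same commit. The derive is what allows `==` and `assert_eq!` on these values, for example in the message-filtering unit test the commit message mentions. A minimal sketch with trimmed stand-ins for the real types (serde derives omitted; the variant names `Text` and `Array` and the single `text` field are assumptions, since the diff does not show them):

```rust
/// Trimmed stand-in for `ContentPart`; the real struct has more fields.
#[derive(Debug, Default, Clone, PartialEq)]
struct ContentPart {
    text: Option<String>,
}

/// Trimmed stand-in for `Content`; variant names are assumed.
#[derive(Debug, Clone, PartialEq)]
enum Content {
    Text(String),
    Array(Vec<ContentPart>),
}

fn main() {
    let a = Content::Text("hello".to_string());
    let b = Content::Text("hello".to_string());
    // Same variant, same payload: equal.
    assert_eq!(a, b);

    // Different variants never compare equal, even with the same text inside.
    let parts = Content::Array(vec![ContentPart { text: Some("hello".to_string()) }]);
    assert_ne!(a, parts);
}
```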
@@ -485,7 +506,7 @@ pub struct Function {
 }

 /// Represents a chat completion response returned by model, based on the provided input.
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct ChatCompletion {
     /// A unique identifier for the chat completion.
     pub id: String,
@@ -506,6 +527,12 @@ pub struct ChatCompletion {
     /// This field is only included if the `service_tier` parameter is specified in the request.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub service_tier: Option<String>,
+    /// Result of running different guardrail detectors
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub detections: Option<ChatDetections>,
+    /// Optional warnings
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub warnings: Vec<OrchestratorWarning>,
 }

 /// A chat completion choice.
@@ -621,7 +648,7 @@ pub struct ChatCompletionDelta {
 }

 /// Usage statistics for a completion.
-#[derive(Debug, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, Serialize, Deserialize)]
 pub struct Usage {
     /// Number of tokens in the prompt.
     pub prompt_tokens: u32,
@@ -665,3 +692,44 @@ pub struct OpenAiError {
     pub param: Option<String>,
     pub code: u16,
 }
+
+/// Guardrails detection results.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ChatDetections {
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub input: Vec<InputDetectionResult>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub output: Vec<OutputDetectionResult>,
+}
+
+/// Guardrails detection result for application on input.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct InputDetectionResult {
+    pub message_index: usize,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub results: Vec<GuardrailDetection>,
+}
+
+/// Guardrails detection result for application output.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OutputDetectionResult {
+    choice_index: usize,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    results: Vec<GuardrailDetection>,
+}
+
+/// Warnings generated by guardrails.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OrchestratorWarning {
+    r#type: InputWarningReason,
+    message: String,
+}
+
+impl OrchestratorWarning {
+    pub fn new(warning_type: InputWarningReason, message: &str) -> Self {
+        Self {
+            r#type: warning_type,
+            message: message.to_string(),
+        }
+    }
+}
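
Note on `OrchestratorWarning`: its fields are private, so code outside `src/clients/openai.rs` has to build warnings through `OrchestratorWarning::new`. The field is named `r#type` because `type` is a reserved word in Rust; the raw-identifier prefix lets the struct keep the name `type`. The sketch below shows the constructor in use. `InputWarningReason` lives in `crate::models` and is not shown in this diff, so the placeholder enum and its `UnsuitableInput` variant are assumptions, as is the message text.

```rust
/// Placeholder for `crate::models::InputWarningReason`; the variant name is assumed.
#[derive(Debug, Clone, PartialEq)]
enum InputWarningReason {
    UnsuitableInput,
}

/// Copy of the struct from the diff, without the serde derives.
#[derive(Debug, Clone)]
struct OrchestratorWarning {
    r#type: InputWarningReason,
    message: String,
}

impl OrchestratorWarning {
    fn new(warning_type: InputWarningReason, message: &str) -> Self {
        Self {
            r#type: warning_type,
            message: message.to_string(),
        }
    }
}

fn main() {
    // A response starts with no warnings; one is pushed when input detection flags a message.
    let mut warnings: Vec<OrchestratorWarning> = Vec::new();
    warnings.push(OrchestratorWarning::new(
        InputWarningReason::UnsuitableInput,
        "Unsuitable input detected.", // illustrative text, not taken from the commit
    ));

    assert_eq!(warnings.len(), 1);
    assert_eq!(warnings[0].r#type, InputWarningReason::UnsuitableInput);
    assert_eq!(warnings[0].message, "Unsuitable input detected.");
}
```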

src/models.rs

Lines changed: 7 additions & 0 deletions
@@ -908,6 +908,13 @@ pub struct DetectionResult {
     pub evidence: Option<Vec<EvidenceObj>>,
 }

+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum GuardrailDetection {
+    ContentAnalysisResponse(ContentAnalysisResponse),
+    ClassificationResult(DetectionResult),
+}
+
 /// The request format expected in the /api/v2/text/context endpoint.
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ContextDocsHttpRequest {
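
Note on `GuardrailDetection`: with `#[serde(untagged)]` the serialized JSON carries no variant name. On deserialization serde tries the variants in declaration order and keeps the first one that parses, so `ContentAnalysisResponse` is attempted before `ClassificationResult`. Consumers then distinguish the two shapes with a `match`. The sketch below shows only that consuming side, without serde. The two payload structs are trimmed placeholders whose fields are assumptions, and `describe` and the detection labels are hypothetical.

```rust
/// Trimmed placeholder for a span-level result; field names are assumed.
#[derive(Debug, Clone)]
struct ContentAnalysisResponse {
    start: usize,
    end: usize,
    detection: String,
    score: f64,
}

/// Trimmed placeholder for a whole-message classification result; field names are assumed.
#[derive(Debug, Clone)]
struct DetectionResult {
    detection: String,
    score: f64,
}

/// Same shape as the enum added in the diff, without the serde derives.
#[derive(Debug, Clone)]
enum GuardrailDetection {
    ContentAnalysisResponse(ContentAnalysisResponse),
    ClassificationResult(DetectionResult),
}

/// Hypothetical helper that renders either kind of detection as one line.
fn describe(detection: &GuardrailDetection) -> String {
    match detection {
        GuardrailDetection::ContentAnalysisResponse(r) => {
            format!("{} at {}..{} (score {:.2})", r.detection, r.start, r.end, r.score)
        }
        GuardrailDetection::ClassificationResult(r) => {
            format!("{} (score {:.2})", r.detection, r.score)
        }
    }
}

fn main() {
    let span = GuardrailDetection::ContentAnalysisResponse(ContentAnalysisResponse {
        start: 0,
        end: 5,
        detection: "pii".to_string(),
        score: 0.9,
    });
    let whole = GuardrailDetection::ClassificationResult(DetectionResult {
        detection: "hap".to_string(),
        score: 0.5,
    });

    assert_eq!(describe(&span), "pii at 0..5 (score 0.90)");
    assert_eq!(describe(&whole), "hap (score 0.50)");
}
```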

src/orchestrator.rs

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ pub mod errors;
 pub use errors::Error;
 use futures::Stream;
 pub mod chat_completions_detection;
+pub mod detector_processing;
 pub mod streaming;
 pub mod streaming_content_detection;
 pub mod unary;
