@@ -23,6 +23,7 @@ use futures::StreamExt;
2323use http_body_util:: BodyExt ;
2424use hyper:: { HeaderMap , StatusCode } ;
2525use serde:: { Deserialize , Serialize } ;
26+ use serde_json:: { Map , Value } ;
2627use tokio:: sync:: mpsc;
2728
2829use super :: {
@@ -32,7 +33,7 @@ use super::{
3233use crate :: {
3334 config:: ServiceConfig ,
3435 health:: HealthCheckResult ,
35- models:: { DetectionWarningReason , DetectorParams } ,
36+ models:: { DetectionWarningReason , DetectorParams , ValidationError } ,
3637 orchestrator,
3738} ;
3839
@@ -167,122 +168,83 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
167168 }
168169}
169170
170- #[ derive( Debug , Default , Clone , Serialize , Deserialize ) ]
171- #[ serde( deny_unknown_fields) ]
171+ /// Represents a chat completions request.
172+ ///
173+ /// As orchestrator is only concerned with a limited subset
174+ /// of request fields, we deserialize to an inner [`serde_json::Map`]
175+ /// and only validate and extract the fields used by this service.
176+ /// This type is then serialized to the inner [`serde_json::Map`].
177+ ///
178+ /// This is to avoid tracking and updating OpenAI and vLLM
179+ /// parameter additions/changes. Full validation is delegated to
180+ /// the downstream server implementation.
181+ ///
182+ /// Validated fields: detectors (internal), model, messages
183+ #[ derive( Debug , Default , Clone , PartialEq , Deserialize ) ]
184+ #[ serde( try_from = "Map<String, Value>" ) ]
172185pub struct ChatCompletionsRequest {
173- /// A list of messages comprising the conversation so far.
174- pub messages : Vec < Message > ,
175- /// ID of the model to use.
176- pub model : String ,
177- /// Whether or not to store the output of this chat completion request.
178- #[ serde( skip_serializing_if = "Option::is_none" ) ]
179- pub store : Option < bool > ,
180- /// Developer-defined tags and values.
181- #[ serde( skip_serializing_if = "Option::is_none" ) ]
182- pub metadata : Option < serde_json:: Value > ,
183- #[ serde( skip_serializing_if = "Option::is_none" ) ]
184- pub frequency_penalty : Option < f32 > ,
185- /// Modify the likelihood of specified tokens appearing in the completion.
186- #[ serde( skip_serializing_if = "Option::is_none" ) ]
187- pub logit_bias : Option < HashMap < String , f32 > > ,
188- /// Whether to return log probabilities of the output tokens or not.
189- /// If true, returns the log probabilities of each output token returned in the content of message.
190- #[ serde( skip_serializing_if = "Option::is_none" ) ]
191- pub logprobs : Option < bool > ,
192- /// An integer between 0 and 20 specifying the number of most likely tokens to return
193- /// at each token position, each with an associated log probability.
194- /// logprobs must be set to true if this parameter is used.
195- #[ serde( skip_serializing_if = "Option::is_none" ) ]
196- pub top_logprobs : Option < u32 > ,
197- /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED)
198- #[ serde( skip_serializing_if = "Option::is_none" ) ]
199- pub max_tokens : Option < u32 > ,
200- /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
201- #[ serde( skip_serializing_if = "Option::is_none" ) ]
202- pub max_completion_tokens : Option < u32 > ,
203- /// How many chat completion choices to generate for each input message.
204- #[ serde( skip_serializing_if = "Option::is_none" ) ]
205- pub n : Option < u32 > ,
206- /// Positive values penalize new tokens based on whether they appear in the text so far,
207- /// increasing the model's likelihood to talk about new topics.
208- #[ serde( skip_serializing_if = "Option::is_none" ) ]
209- pub presence_penalty : Option < f32 > ,
210- /// An object specifying the format that the model must output.
211- #[ serde( skip_serializing_if = "Option::is_none" ) ]
212- pub response_format : Option < ResponseFormat > ,
213- /// If specified, our system will make a best effort to sample deterministically,
214- /// such that repeated requests with the same seed and parameters should return the same result.
215- #[ serde( skip_serializing_if = "Option::is_none" ) ]
216- pub seed : Option < u64 > ,
217- /// Specifies the latency tier to use for processing the request.
218- #[ serde( skip_serializing_if = "Option::is_none" ) ]
219- pub service_tier : Option < String > ,
220- /// Up to 4 sequences where the API will stop generating further tokens.
221- #[ serde( skip_serializing_if = "Option::is_none" ) ]
222- pub stop : Option < StopTokens > ,
223- /// If set, partial message deltas will be sent, like in ChatGPT.
224- /// Tokens will be sent as data-only server-sent events as they become available,
225- /// with the stream terminated by a data: [DONE] message.
226- #[ serde( default ) ]
186+ /// Detector config.
187+ pub detectors : DetectorConfig ,
188+ /// Stream parameter.
227189 pub stream : bool ,
228- /// Options for streaming response. Only set this when you set stream: true.
229- #[ serde( skip_serializing_if = "Option::is_none" ) ]
230- pub stream_options : Option < StreamOptions > ,
231- /// What sampling temperature to use, between 0 and 2.
232- /// Higher values like 0.8 will make the output more random,
233- /// while lower values like 0.2 will make it more focused and deterministic.
234- #[ serde( skip_serializing_if = "Option::is_none" ) ]
235- pub temperature : Option < f32 > ,
236- /// An alternative to sampling with temperature, called nucleus sampling,
237- /// where the model considers the results of the tokens with top_p probability mass.
238- /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
239- #[ serde( skip_serializing_if = "Option::is_none" ) ]
240- pub top_p : Option < f32 > ,
241- /// A list of tools the model may call.
242- #[ serde( default , skip_serializing_if = "Vec::is_empty" ) ]
243- pub tools : Vec < Tool > ,
244- /// Controls which (if any) tool is called by the model.
245- #[ serde( skip_serializing_if = "Option::is_none" ) ]
246- pub tool_choice : Option < ToolChoice > ,
247- /// Whether to enable parallel function calling during tool use.
248- #[ serde( skip_serializing_if = "Option::is_none" ) ]
249- pub parallel_tool_calls : Option < bool > ,
250- /// A unique identifier representing your end-user.
251- #[ serde( skip_serializing_if = "Option::is_none" ) ]
252- pub user : Option < String > ,
190+ /// Model name.
191+ pub model : String ,
192+ /// Messages.
193+ pub messages : Vec < Message > ,
194+ /// Inner request.
195+ pub inner : Map < String , Value > ,
196+ }
253197
254- // Additional vllm params
255- #[ serde( skip_serializing_if = "Option::is_none" ) ]
256- pub best_of : Option < usize > ,
257- #[ serde( skip_serializing_if = "Option::is_none" ) ]
258- pub use_beam_search : Option < bool > ,
259- #[ serde( skip_serializing_if = "Option::is_none" ) ]
260- pub top_k : Option < isize > ,
261- #[ serde( skip_serializing_if = "Option::is_none" ) ]
262- pub min_p : Option < f32 > ,
263- #[ serde( skip_serializing_if = "Option::is_none" ) ]
264- pub repetition_penalty : Option < f32 > ,
265- #[ serde( skip_serializing_if = "Option::is_none" ) ]
266- pub length_penalty : Option < f32 > ,
267- #[ serde( skip_serializing_if = "Option::is_none" ) ]
268- pub early_stopping : Option < bool > ,
269- #[ serde( skip_serializing_if = "Option::is_none" ) ]
270- pub ignore_eos : Option < bool > ,
271- #[ serde( skip_serializing_if = "Option::is_none" ) ]
272- pub min_tokens : Option < u32 > ,
273- #[ serde( skip_serializing_if = "Option::is_none" ) ]
274- pub stop_token_ids : Option < Vec < usize > > ,
275- #[ serde( skip_serializing_if = "Option::is_none" ) ]
276- pub skip_special_tokens : Option < bool > ,
277- #[ serde( skip_serializing_if = "Option::is_none" ) ]
278- pub spaces_between_special_tokens : Option < bool > ,
198+ impl TryFrom < Map < String , Value > > for ChatCompletionsRequest {
199+ type Error = ValidationError ;
279200
280- // Detectors
281- // Note: We are making it optional, since this structure also gets used to
282- // form request for chat completions. And downstream server, might choose to
283- // reject extra parameters.
284- #[ serde( skip_serializing_if = "Option::is_none" ) ]
285- pub detectors : Option < DetectorConfig > ,
201+ fn try_from ( mut value : Map < String , Value > ) -> Result < Self , Self :: Error > {
202+ let detectors = if let Some ( detectors) = value. remove ( "detectors" ) {
203+ DetectorConfig :: deserialize ( detectors)
204+ . map_err ( |_| ValidationError :: Invalid ( "error deserializing `detectors`" . into ( ) ) ) ?
205+ } else {
206+ DetectorConfig :: default ( )
207+ } ;
208+ let stream = value
209+ . get ( "stream" )
210+ . and_then ( |v| v. as_bool ( ) )
211+ . unwrap_or_default ( ) ;
212+ let model = if let Some ( Value :: String ( model) ) = value. get ( "model" ) {
213+ Ok ( model. clone ( ) )
214+ } else {
215+ Err ( ValidationError :: Required ( "model" . into ( ) ) )
216+ } ?;
217+ if model. is_empty ( ) {
218+ return Err ( ValidationError :: Invalid ( "`model` must not be empty" . into ( ) ) ) ;
219+ }
220+ let messages = if let Some ( messages) = value. get ( "messages" ) {
221+ Vec :: < Message > :: deserialize ( messages)
222+ . map_err ( |_| ValidationError :: Invalid ( "error deserializing `messages`" . into ( ) ) )
223+ } else {
224+ Err ( ValidationError :: Required ( "messages" . into ( ) ) )
225+ } ?;
226+ if messages. is_empty ( ) {
227+ return Err ( ValidationError :: Invalid (
228+ "`messages` must not be empty" . into ( ) ,
229+ ) ) ;
230+ }
231+ Ok ( ChatCompletionsRequest {
232+ detectors,
233+ stream,
234+ model,
235+ messages,
236+ inner : value,
237+ } )
238+ }
239+ }
240+
241+ impl Serialize for ChatCompletionsRequest {
242+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
243+ where
244+ S : serde:: Serializer ,
245+ {
246+ self . inner . serialize ( serializer)
247+ }
286248}
287249
288250/// Structure to contain parameters for detectors.
@@ -291,7 +253,6 @@ pub struct ChatCompletionsRequest {
291253pub struct DetectorConfig {
292254 #[ serde( default , skip_serializing_if = "HashMap::is_empty" ) ]
293255 pub input : HashMap < String , DetectorParams > ,
294-
295256 #[ serde( default , skip_serializing_if = "HashMap::is_empty" ) ]
296257 pub output : HashMap < String , DetectorParams > ,
297258}
@@ -369,7 +330,7 @@ pub enum Role {
369330 Tool ,
370331}
371332
372- #[ derive( Debug , Default , Clone , Serialize , Deserialize ) ]
333+ #[ derive( Debug , Default , Clone , PartialEq , Serialize , Deserialize ) ]
373334#[ serde( deny_unknown_fields) ]
374335pub struct Message {
375336 /// The role of the author of this message.
@@ -731,3 +692,103 @@ impl OrchestratorWarning {
731692 }
732693 }
733694}
695+
#[cfg(test)]
mod test {
    use serde_json::json;

    use super::*;

    #[test]
    fn test_chat_completions_request() -> Result<(), serde_json::Error> {
        // Deserialization: `detectors` is extracted and removed from `inner`.
        let detectors = DetectorConfig {
            input: HashMap::from([("some_detector".into(), DetectorParams::new())]),
            output: HashMap::new(),
        };
        let messages = vec![Message {
            content: Some(Content::Text("Hi there!".to_string())),
            ..Default::default()
        }];
        let json_request = json!({
            "model": "test",
            "detectors": detectors,
            "messages": messages,
        });
        let request = ChatCompletionsRequest::deserialize(&json_request)?;
        let mut expected_inner = json_request.as_object().unwrap().to_owned();
        expected_inner.remove("detectors").unwrap();
        assert_eq!(
            request,
            ChatCompletionsRequest {
                detectors,
                stream: false,
                model: "test".into(),
                messages: messages.clone(),
                inner: expected_inner,
            }
        );

        // Deserialization with no detectors falls back to the default config.
        let json_request = json!({
            "model": "test",
            "messages": messages,
        });
        let request = ChatCompletionsRequest::deserialize(&json_request)?;
        let expected_inner = json_request.as_object().unwrap().to_owned();
        assert_eq!(
            request,
            ChatCompletionsRequest {
                detectors: DetectorConfig::default(),
                stream: false,
                model: "test".into(),
                messages: messages.clone(),
                inner: expected_inner,
            }
        );

        // Validation errors surface as deserialization errors.
        let result = ChatCompletionsRequest::deserialize(json!({
            "detectors": DetectorConfig::default(),
            "messages": messages,
        }));
        assert!(result.is_err_and(|error| error.to_string() == "`model` is required"));

        let result = ChatCompletionsRequest::deserialize(json!({
            "model": "",
            "detectors": DetectorConfig::default(),
            "messages": Vec::<Message>::default(),
        }));
        assert!(result.is_err_and(|error| error.to_string() == "`model` must not be empty"));

        let result = ChatCompletionsRequest::deserialize(json!({
            "model": "test",
            "detectors": DetectorConfig::default(),
            "messages": Vec::<Message>::default(),
        }));
        assert!(result.is_err_and(|error| error.to_string() == "`messages` must not be empty"));

        let result = ChatCompletionsRequest::deserialize(json!({
            "model": "test",
            "detectors": DetectorConfig::default(),
            "messages": ["invalid"],
        }));
        assert!(result.is_err_and(|error| error.to_string() == "error deserializing `messages`"));

        // Serialization: only the inner map (sans `detectors`) is emitted.
        let serialized_request = serde_json::to_value(request)?;
        assert_eq!(
            serialized_request,
            json!({
                "model": "test",
                "messages": [Message {
                    content: Some(Content::Text("Hi there!".to_string())),
                    role: Role::User,
                    ..Default::default()
                }],
            })
        );

        Ok(())
    }
}
0 commit comments