@@ -23,6 +23,7 @@ use futures::StreamExt;
23
23
use http_body_util:: BodyExt ;
24
24
use hyper:: { HeaderMap , StatusCode } ;
25
25
use serde:: { Deserialize , Serialize } ;
26
+ use serde_json:: { Map , Value } ;
26
27
use tokio:: sync:: mpsc;
27
28
28
29
use super :: {
@@ -32,7 +33,7 @@ use super::{
32
33
use crate :: {
33
34
config:: ServiceConfig ,
34
35
health:: HealthCheckResult ,
35
- models:: { DetectionWarningReason , DetectorParams } ,
36
+ models:: { DetectionWarningReason , DetectorParams , ValidationError } ,
36
37
orchestrator,
37
38
} ;
38
39
@@ -167,122 +168,83 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
167
168
}
168
169
}
169
170
170
- #[ derive( Debug , Default , Clone , Serialize , Deserialize ) ]
171
- #[ serde( deny_unknown_fields) ]
171
+ /// Represents a chat completions request.
172
+ ///
173
+ /// As orchestrator is only concerned with a limited subset
174
+ /// of request fields, we deserialize to an inner [`serde_json::Map`]
175
+ /// and only validate and extract the fields used by this service.
176
+ /// This type is then serialized to the inner [`serde_json::Map`].
177
+ ///
178
+ /// This is to avoid tracking and updating OpenAI and vLLM
179
+ /// parameter additions/changes. Full validation is delegated to
180
+ /// the downstream server implementation.
181
+ ///
182
+ /// Validated fields: detectors (internal), model, messages
183
+ #[ derive( Debug , Default , Clone , PartialEq , Deserialize ) ]
184
+ #[ serde( try_from = "Map<String, Value>" ) ]
172
185
pub struct ChatCompletionsRequest {
173
- /// A list of messages comprising the conversation so far.
174
- pub messages : Vec < Message > ,
175
- /// ID of the model to use.
176
- pub model : String ,
177
- /// Whether or not to store the output of this chat completion request.
178
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
179
- pub store : Option < bool > ,
180
- /// Developer-defined tags and values.
181
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
182
- pub metadata : Option < serde_json:: Value > ,
183
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
184
- pub frequency_penalty : Option < f32 > ,
185
- /// Modify the likelihood of specified tokens appearing in the completion.
186
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
187
- pub logit_bias : Option < HashMap < String , f32 > > ,
188
- /// Whether to return log probabilities of the output tokens or not.
189
- /// If true, returns the log probabilities of each output token returned in the content of message.
190
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
191
- pub logprobs : Option < bool > ,
192
- /// An integer between 0 and 20 specifying the number of most likely tokens to return
193
- /// at each token position, each with an associated log probability.
194
- /// logprobs must be set to true if this parameter is used.
195
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
196
- pub top_logprobs : Option < u32 > ,
197
- /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED)
198
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
199
- pub max_tokens : Option < u32 > ,
200
- /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
201
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
202
- pub max_completion_tokens : Option < u32 > ,
203
- /// How many chat completion choices to generate for each input message.
204
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
205
- pub n : Option < u32 > ,
206
- /// Positive values penalize new tokens based on whether they appear in the text so far,
207
- /// increasing the model's likelihood to talk about new topics.
208
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
209
- pub presence_penalty : Option < f32 > ,
210
- /// An object specifying the format that the model must output.
211
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
212
- pub response_format : Option < ResponseFormat > ,
213
- /// If specified, our system will make a best effort to sample deterministically,
214
- /// such that repeated requests with the same seed and parameters should return the same result.
215
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
216
- pub seed : Option < u64 > ,
217
- /// Specifies the latency tier to use for processing the request.
218
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
219
- pub service_tier : Option < String > ,
220
- /// Up to 4 sequences where the API will stop generating further tokens.
221
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
222
- pub stop : Option < StopTokens > ,
223
- /// If set, partial message deltas will be sent, like in ChatGPT.
224
- /// Tokens will be sent as data-only server-sent events as they become available,
225
- /// with the stream terminated by a data: [DONE] message.
226
- #[ serde( default ) ]
186
+ /// Detector config.
187
+ pub detectors : DetectorConfig ,
188
+ /// Stream parameter.
227
189
pub stream : bool ,
228
- /// Options for streaming response. Only set this when you set stream: true.
229
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
230
- pub stream_options : Option < StreamOptions > ,
231
- /// What sampling temperature to use, between 0 and 2.
232
- /// Higher values like 0.8 will make the output more random,
233
- /// while lower values like 0.2 will make it more focused and deterministic.
234
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
235
- pub temperature : Option < f32 > ,
236
- /// An alternative to sampling with temperature, called nucleus sampling,
237
- /// where the model considers the results of the tokens with top_p probability mass.
238
- /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
239
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
240
- pub top_p : Option < f32 > ,
241
- /// A list of tools the model may call.
242
- #[ serde( default , skip_serializing_if = "Vec::is_empty" ) ]
243
- pub tools : Vec < Tool > ,
244
- /// Controls which (if any) tool is called by the model.
245
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
246
- pub tool_choice : Option < ToolChoice > ,
247
- /// Whether to enable parallel function calling during tool use.
248
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
249
- pub parallel_tool_calls : Option < bool > ,
250
- /// A unique identifier representing your end-user.
251
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
252
- pub user : Option < String > ,
190
+ /// Model name.
191
+ pub model : String ,
192
+ /// Messages.
193
+ pub messages : Vec < Message > ,
194
+ /// Inner request.
195
+ pub inner : Map < String , Value > ,
196
+ }
253
197
254
- // Additional vllm params
255
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
256
- pub best_of : Option < usize > ,
257
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
258
- pub use_beam_search : Option < bool > ,
259
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
260
- pub top_k : Option < isize > ,
261
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
262
- pub min_p : Option < f32 > ,
263
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
264
- pub repetition_penalty : Option < f32 > ,
265
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
266
- pub length_penalty : Option < f32 > ,
267
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
268
- pub early_stopping : Option < bool > ,
269
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
270
- pub ignore_eos : Option < bool > ,
271
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
272
- pub min_tokens : Option < u32 > ,
273
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
274
- pub stop_token_ids : Option < Vec < usize > > ,
275
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
276
- pub skip_special_tokens : Option < bool > ,
277
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
278
- pub spaces_between_special_tokens : Option < bool > ,
198
+ impl TryFrom < Map < String , Value > > for ChatCompletionsRequest {
199
+ type Error = ValidationError ;
279
200
280
- // Detectors
281
- // Note: We are making it optional, since this structure also gets used to
282
- // form request for chat completions. And downstream server, might choose to
283
- // reject extra parameters.
284
- #[ serde( skip_serializing_if = "Option::is_none" ) ]
285
- pub detectors : Option < DetectorConfig > ,
201
+ fn try_from ( mut value : Map < String , Value > ) -> Result < Self , Self :: Error > {
202
+ let detectors = if let Some ( detectors) = value. remove ( "detectors" ) {
203
+ DetectorConfig :: deserialize ( detectors)
204
+ . map_err ( |_| ValidationError :: Invalid ( "error deserializing `detectors`" . into ( ) ) ) ?
205
+ } else {
206
+ DetectorConfig :: default ( )
207
+ } ;
208
+ let stream = value
209
+ . get ( "stream" )
210
+ . and_then ( |v| v. as_bool ( ) )
211
+ . unwrap_or_default ( ) ;
212
+ let model = if let Some ( Value :: String ( model) ) = value. get ( "model" ) {
213
+ Ok ( model. clone ( ) )
214
+ } else {
215
+ Err ( ValidationError :: Required ( "model" . into ( ) ) )
216
+ } ?;
217
+ if model. is_empty ( ) {
218
+ return Err ( ValidationError :: Invalid ( "`model` must not be empty" . into ( ) ) ) ;
219
+ }
220
+ let messages = if let Some ( messages) = value. get ( "messages" ) {
221
+ Vec :: < Message > :: deserialize ( messages)
222
+ . map_err ( |_| ValidationError :: Invalid ( "error deserializing `messages`" . into ( ) ) )
223
+ } else {
224
+ Err ( ValidationError :: Required ( "messages" . into ( ) ) )
225
+ } ?;
226
+ if messages. is_empty ( ) {
227
+ return Err ( ValidationError :: Invalid (
228
+ "`messages` must not be empty" . into ( ) ,
229
+ ) ) ;
230
+ }
231
+ Ok ( ChatCompletionsRequest {
232
+ detectors,
233
+ stream,
234
+ model,
235
+ messages,
236
+ inner : value,
237
+ } )
238
+ }
239
+ }
240
+
241
+ impl Serialize for ChatCompletionsRequest {
242
+ fn serialize < S > ( & self , serializer : S ) -> Result < S :: Ok , S :: Error >
243
+ where
244
+ S : serde:: Serializer ,
245
+ {
246
+ self . inner . serialize ( serializer)
247
+ }
286
248
}
287
249
288
250
/// Structure to contain parameters for detectors.
@@ -291,7 +253,6 @@ pub struct ChatCompletionsRequest {
291
253
pub struct DetectorConfig {
292
254
#[ serde( default , skip_serializing_if = "HashMap::is_empty" ) ]
293
255
pub input : HashMap < String , DetectorParams > ,
294
-
295
256
#[ serde( default , skip_serializing_if = "HashMap::is_empty" ) ]
296
257
pub output : HashMap < String , DetectorParams > ,
297
258
}
@@ -369,7 +330,7 @@ pub enum Role {
369
330
Tool ,
370
331
}
371
332
372
- #[ derive( Debug , Default , Clone , Serialize , Deserialize ) ]
333
+ #[ derive( Debug , Default , Clone , PartialEq , Serialize , Deserialize ) ]
373
334
#[ serde( deny_unknown_fields) ]
374
335
pub struct Message {
375
336
/// The role of the author of this message.
@@ -731,3 +692,103 @@ impl OrchestratorWarning {
731
692
}
732
693
}
733
694
}
695
+
696
#[cfg(test)]
mod test {
    use serde_json::json;

    use super::*;

    /// Covers deserialization (with and without `detectors`), the validation
    /// error messages, and serialization of [`ChatCompletionsRequest`].
    #[test]
    fn test_chat_completions_request() -> Result<(), serde_json::Error> {
        let detectors = DetectorConfig {
            input: HashMap::from([("some_detector".into(), DetectorParams::new())]),
            output: HashMap::new(),
        };
        let messages = vec![Message {
            content: Some(Content::Text("Hi there!".to_string())),
            ..Default::default()
        }];

        // Deserialize: `detectors` is extracted and must not remain in `inner`.
        let payload = json!({
            "model": "test",
            "detectors": detectors,
            "messages": messages,
        });
        let with_detectors = ChatCompletionsRequest::deserialize(&payload)?;
        let mut expected_inner = payload.as_object().unwrap().to_owned();
        expected_inner.remove("detectors").unwrap();
        let expected = ChatCompletionsRequest {
            detectors,
            stream: false,
            model: "test".into(),
            messages: messages.clone(),
            inner: expected_inner,
        };
        assert_eq!(with_detectors, expected);

        // Deserialize with no detectors: the default config is used and `inner`
        // equals the payload as sent.
        let payload = json!({
            "model": "test",
            "messages": messages,
        });
        let without_detectors = ChatCompletionsRequest::deserialize(&payload)?;
        let expected = ChatCompletionsRequest {
            detectors: DetectorConfig::default(),
            stream: false,
            model: "test".into(),
            messages: messages.clone(),
            inner: payload.as_object().unwrap().to_owned(),
        };
        assert_eq!(without_detectors, expected);

        // Deserialize validation errors: each payload paired with its message.
        let invalid_cases = [
            (
                json!({
                    "detectors": DetectorConfig::default(),
                    "messages": messages,
                }),
                "`model` is required",
            ),
            (
                json!({
                    "model": "",
                    "detectors": DetectorConfig::default(),
                    "messages": Vec::<Message>::default(),
                }),
                "`model` must not be empty",
            ),
            (
                json!({
                    "model": "test",
                    "detectors": DetectorConfig::default(),
                    "messages": Vec::<Message>::default(),
                }),
                "`messages` must not be empty",
            ),
            (
                json!({
                    "model": "test",
                    "detectors": DetectorConfig::default(),
                    "messages": ["invalid"],
                }),
                "error deserializing `messages`",
            ),
        ];
        for (invalid_payload, expected_message) in invalid_cases {
            let result = ChatCompletionsRequest::deserialize(invalid_payload);
            assert!(result.is_err_and(|error| error.to_string() == expected_message));
        }

        // Serialize: only the inner map is written out.
        let expected_message = Message {
            content: Some(Content::Text("Hi there!".to_string())),
            role: Role::User,
            ..Default::default()
        };
        let serialized_request = serde_json::to_value(without_detectors)?;
        assert_eq!(
            serialized_request,
            json!({
                "model": "test",
                "messages": [expected_message],
            })
        );

        Ok(())
    }
}
0 commit comments