
Commit 6bc045f

declark1 and mdevino authored
refactor ChatCompletionsRequest (#375)
Signed-off-by: declark1 <44146800+declark1@users.noreply.github.com>
Co-authored-by: Mateus Devino <19861348+mdevino@users.noreply.github.com>
1 parent 9253698 · commit 6bc045f

File tree: 5 files changed, +341 / -332 lines


src/clients/openai.rs (+176, -115)
@@ -23,6 +23,7 @@ use futures::StreamExt;
 use http_body_util::BodyExt;
 use hyper::{HeaderMap, StatusCode};
 use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
 use tokio::sync::mpsc;
 
 use super::{
@@ -32,7 +33,7 @@ use super::{
 use crate::{
     config::ServiceConfig,
     health::HealthCheckResult,
-    models::{DetectionWarningReason, DetectorParams},
+    models::{DetectionWarningReason, DetectorParams, ValidationError},
     orchestrator,
 };
 
@@ -167,122 +168,83 @@ impl From<ChatCompletion> for ChatCompletionsResponse {
     }
 }
 
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
-#[serde(deny_unknown_fields)]
+/// Represents a chat completions request.
+///
+/// As orchestrator is only concerned with a limited subset
+/// of request fields, we deserialize to an inner [`serde_json::Map`]
+/// and only validate and extract the fields used by this service.
+/// This type is then serialized to the inner [`serde_json::Map`].
+///
+/// This is to avoid tracking and updating OpenAI and vLLM
+/// parameter additions/changes. Full validation is delegated to
+/// the downstream server implementation.
+///
+/// Validated fields: detectors (internal), model, messages
+#[derive(Debug, Default, Clone, PartialEq, Deserialize)]
+#[serde(try_from = "Map<String, Value>")]
 pub struct ChatCompletionsRequest {
-    /// A list of messages comprising the conversation so far.
-    pub messages: Vec<Message>,
-    /// ID of the model to use.
-    pub model: String,
-    /// Whether or not to store the output of this chat completion request.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub store: Option<bool>,
-    /// Developer-defined tags and values.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub metadata: Option<serde_json::Value>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub frequency_penalty: Option<f32>,
-    /// Modify the likelihood of specified tokens appearing in the completion.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub logit_bias: Option<HashMap<String, f32>>,
-    /// Whether to return log probabilities of the output tokens or not.
-    /// If true, returns the log probabilities of each output token returned in the content of message.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub logprobs: Option<bool>,
-    /// An integer between 0 and 20 specifying the number of most likely tokens to return
-    /// at each token position, each with an associated log probability.
-    /// logprobs must be set to true if this parameter is used.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_logprobs: Option<u32>,
-    /// The maximum number of tokens that can be generated in the chat completion. (DEPRECATED)
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<u32>,
-    /// An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and reasoning tokens.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub max_completion_tokens: Option<u32>,
-    /// How many chat completion choices to generate for each input message.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub n: Option<u32>,
-    /// Positive values penalize new tokens based on whether they appear in the text so far,
-    /// increasing the model's likelihood to talk about new topics.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub presence_penalty: Option<f32>,
-    /// An object specifying the format that the model must output.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub response_format: Option<ResponseFormat>,
-    /// If specified, our system will make a best effort to sample deterministically,
-    /// such that repeated requests with the same seed and parameters should return the same result.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub seed: Option<u64>,
-    /// Specifies the latency tier to use for processing the request.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub service_tier: Option<String>,
-    /// Up to 4 sequences where the API will stop generating further tokens.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop: Option<StopTokens>,
-    /// If set, partial message deltas will be sent, like in ChatGPT.
-    /// Tokens will be sent as data-only server-sent events as they become available,
-    /// with the stream terminated by a data: [DONE] message.
-    #[serde(default)]
+    /// Detector config.
+    pub detectors: DetectorConfig,
+    /// Stream parameter.
     pub stream: bool,
-    /// Options for streaming response. Only set this when you set stream: true.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stream_options: Option<StreamOptions>,
-    /// What sampling temperature to use, between 0 and 2.
-    /// Higher values like 0.8 will make the output more random,
-    /// while lower values like 0.2 will make it more focused and deterministic.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f32>,
-    /// An alternative to sampling with temperature, called nucleus sampling,
-    /// where the model considers the results of the tokens with top_p probability mass.
-    /// So 0.1 means only the tokens comprising the top 10% probability mass are considered.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f32>,
-    /// A list of tools the model may call.
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub tools: Vec<Tool>,
-    /// Controls which (if any) tool is called by the model.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub tool_choice: Option<ToolChoice>,
-    /// Whether to enable parallel function calling during tool use.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub parallel_tool_calls: Option<bool>,
-    /// A unique identifier representing your end-user.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub user: Option<String>,
+    /// Model name.
+    pub model: String,
+    /// Messages.
+    pub messages: Vec<Message>,
+    /// Inner request.
+    pub inner: Map<String, Value>,
+}
 
-    // Additional vllm params
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub best_of: Option<usize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub use_beam_search: Option<bool>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_k: Option<isize>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub min_p: Option<f32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub repetition_penalty: Option<f32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub length_penalty: Option<f32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub early_stopping: Option<bool>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub ignore_eos: Option<bool>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub min_tokens: Option<u32>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_token_ids: Option<Vec<usize>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub skip_special_tokens: Option<bool>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub spaces_between_special_tokens: Option<bool>,
+impl TryFrom<Map<String, Value>> for ChatCompletionsRequest {
+    type Error = ValidationError;
 
-    // Detectors
-    // Note: We are making it optional, since this structure also gets used to
-    // form request for chat completions. And downstream server, might choose to
-    // reject extra parameters.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub detectors: Option<DetectorConfig>,
+    fn try_from(mut value: Map<String, Value>) -> Result<Self, Self::Error> {
+        let detectors = if let Some(detectors) = value.remove("detectors") {
+            DetectorConfig::deserialize(detectors)
+                .map_err(|_| ValidationError::Invalid("error deserializing `detectors`".into()))?
+        } else {
+            DetectorConfig::default()
+        };
+        let stream = value
+            .get("stream")
+            .and_then(|v| v.as_bool())
+            .unwrap_or_default();
+        let model = if let Some(Value::String(model)) = value.get("model") {
+            Ok(model.clone())
+        } else {
+            Err(ValidationError::Required("model".into()))
+        }?;
+        if model.is_empty() {
+            return Err(ValidationError::Invalid("`model` must not be empty".into()));
+        }
+        let messages = if let Some(messages) = value.get("messages") {
+            Vec::<Message>::deserialize(messages)
+                .map_err(|_| ValidationError::Invalid("error deserializing `messages`".into()))
+        } else {
+            Err(ValidationError::Required("messages".into()))
+        }?;
+        if messages.is_empty() {
+            return Err(ValidationError::Invalid(
+                "`messages` must not be empty".into(),
+            ));
+        }
+        Ok(ChatCompletionsRequest {
+            detectors,
+            stream,
+            model,
+            messages,
+            inner: value,
+        })
+    }
+}
+
+impl Serialize for ChatCompletionsRequest {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        self.inner.serialize(serializer)
+    }
 }
 
 /// Structure to contain parameters for detectors.
@@ -291,7 +253,6 @@ pub struct ChatCompletionsRequest {
 pub struct DetectorConfig {
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub input: HashMap<String, DetectorParams>,
-
     #[serde(default, skip_serializing_if = "HashMap::is_empty")]
     pub output: HashMap<String, DetectorParams>,
 }
@@ -369,7 +330,7 @@ pub enum Role {
     Tool,
 }
 
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
+#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
 #[serde(deny_unknown_fields)]
 pub struct Message {
     /// The role of the author of this message.
@@ -731,3 +692,103 @@ impl OrchestratorWarning {
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use serde_json::json;
+
+    use super::*;
+
+    #[test]
+    fn test_chat_completions_request() -> Result<(), serde_json::Error> {
+        // Test deserialize
+        let detectors = DetectorConfig {
+            input: HashMap::from([("some_detector".into(), DetectorParams::new())]),
+            output: HashMap::new(),
+        };
+        let messages = vec![Message {
+            content: Some(Content::Text("Hi there!".to_string())),
+            ..Default::default()
+        }];
+        let json_request = json!({
+            "model": "test",
+            "detectors": detectors,
+            "messages": messages,
+        });
+        let request = ChatCompletionsRequest::deserialize(&json_request)?;
+        let mut inner = json_request.as_object().unwrap().to_owned();
+        inner.remove("detectors").unwrap();
+        assert_eq!(
+            request,
+            ChatCompletionsRequest {
+                detectors,
+                stream: false,
+                model: "test".into(),
+                messages: messages.clone(),
+                inner,
+            }
+        );
+
+        // Test deserialize with no detectors
+        let json_request = json!({
+            "model": "test",
+            "messages": messages,
+        });
+        let request = ChatCompletionsRequest::deserialize(&json_request)?;
+        let inner = json_request.as_object().unwrap().to_owned();
+        assert_eq!(
+            request,
+            ChatCompletionsRequest {
+                detectors: DetectorConfig::default(),
+                stream: false,
+                model: "test".into(),
+                messages: messages.clone(),
+                inner,
+            }
+        );
+
+        // Test deserialize validation errors
+        let result = ChatCompletionsRequest::deserialize(json!({
+            "detectors": DetectorConfig::default(),
+            "messages": messages,
+        }));
+        assert!(result.is_err_and(|error| error.to_string() == "`model` is required"));
+
+        let result = ChatCompletionsRequest::deserialize(json!({
+            "model": "",
+            "detectors": DetectorConfig::default(),
+            "messages": Vec::<Message>::default(),
+        }));
+        assert!(result.is_err_and(|error| error.to_string() == "`model` must not be empty"));
+
+        let result = ChatCompletionsRequest::deserialize(json!({
+            "model": "test",
+            "detectors": DetectorConfig::default(),
+            "messages": Vec::<Message>::default(),
+        }));
+        assert!(result.is_err_and(|error| error.to_string() == "`messages` must not be empty"));
+
+        let result = ChatCompletionsRequest::deserialize(json!({
+            "model": "test",
+            "detectors": DetectorConfig::default(),
+            "messages": ["invalid"],
+        }));
+        assert!(result.is_err_and(|error| error.to_string() == "error deserializing `messages`"));
+
+        // Test serialize
+        let serialized_request = serde_json::to_value(request)?;
+        assert_eq!(
+            serialized_request,
+            json!({
+                "model": "test",
+                "messages": [Message {
+                    content: Some(Content::Text("Hi there!".to_string())),
+                    role: Role::User,
+                    ..Default::default()
+                }],
+            })
+        );
+
+        Ok(())
+    }
+}
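
For readers unfamiliar with this serde pattern: the #[serde(try_from = "Map<String, Value>")] container attribute makes serde deserialize the raw JSON object first, then run it through TryFrom for validation, while the hand-written Serialize impl writes the untouched map back out. Below is a minimal, self-contained sketch of the same round-trip idea; the Request name, the String error type, and the future_param field are illustrative stand-ins, not code from this commit.

use serde::{Deserialize, Serialize};
use serde_json::{json, Map, Value};

#[derive(Debug, Deserialize)]
#[serde(try_from = "Map<String, Value>")]
struct Request {
    /// Extracted and validated field.
    model: String,
    /// Everything else rides along untouched.
    inner: Map<String, Value>,
}

impl TryFrom<Map<String, Value>> for Request {
    // Stand-in for the orchestrator's ValidationError.
    type Error = String;

    fn try_from(value: Map<String, Value>) -> Result<Self, Self::Error> {
        let model = match value.get("model") {
            Some(Value::String(model)) if !model.is_empty() => model.clone(),
            _ => return Err("`model` is required and must be a non-empty string".into()),
        };
        Ok(Request { model, inner: value })
    }
}

impl Serialize for Request {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Serialize the raw map so unknown fields survive the round trip.
        self.inner.serialize(serializer)
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `future_param` is not modeled anywhere, yet it survives re-serialization,
    // which is why the struct no longer needs `deny_unknown_fields`.
    let request: Request = serde_json::from_value(json!({"model": "test", "future_param": 42}))?;
    assert_eq!(request.model, "test");
    assert_eq!(
        serde_json::to_value(&request)?,
        json!({"model": "test", "future_param": 42})
    );
    Ok(())
}

The trade-off relative to the old deny_unknown_fields struct is that the orchestrator no longer rejects misspelled or unknown parameters itself; per the doc comment above, that validation is delegated to the downstream server.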

src/orchestrator/common/client.rs (+2, -6)
@@ -247,10 +247,8 @@ pub async fn detect_text_context(
 pub async fn chat_completion(
     client: &OpenAiClient,
     headers: HeaderMap,
-    mut request: openai::ChatCompletionsRequest,
+    request: openai::ChatCompletionsRequest,
 ) -> Result<openai::ChatCompletionsResponse, Error> {
-    request.stream = false;
-    request.detectors = None;
     let model_id = request.model.clone();
     debug!(%model_id, ?request, "sending chat completions request");
     let response = client
@@ -269,10 +267,8 @@ pub async fn chat_completion(
 pub async fn chat_completion_stream(
     client: &OpenAiClient,
     headers: HeaderMap,
-    mut request: openai::ChatCompletionsRequest,
+    request: openai::ChatCompletionsRequest,
 ) -> Result<ChatCompletionStream, Error> {
-    request.stream = true;
-    request.detectors = None;
     let model_id = request.model.clone();
     debug!(%model_id, ?request, "sending chat completions stream request");
     let response = client
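
The mut request plumbing disappears as a consequence of the serialization change above: Serialize now delegates to inner, so assigning to the typed stream field would not affect the outgoing payload, and detectors (already stripped from inner during try_from, and no longer an Option) never reaches the downstream server in the first place. Under that reading, overriding a passthrough field would now mean editing the map itself; a hypothetical helper, reusing the stubbed Request type from the sketch above (not part of this commit):

// Hypothetical, for illustration only: overriding a passthrough field now
// means editing the inner map, since the typed fields are not serialized.
fn set_stream(request: &mut Request, stream: bool) {
    request.inner.insert("stream".to_string(), Value::Bool(stream));
}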

src/orchestrator/handlers/chat_completions_detection/streaming.rs (+1, -1)
@@ -30,7 +30,7 @@ pub async fn handle_streaming(
     task: ChatCompletionsDetectionTask,
 ) -> Result<ChatCompletionsResponse, Error> {
     let trace_id = task.trace_id;
-    let detectors = task.request.detectors.clone().unwrap_or_default();
+    let detectors = task.request.detectors.clone();
     info!(%trace_id, config = ?detectors, "task started");
     let _input_detectors = detectors.input;
     let _output_detectors = detectors.output;

src/orchestrator/handlers/chat_completions_detection/unary.rs (+1, -1)
@@ -39,7 +39,7 @@ pub async fn handle_unary(
     task: ChatCompletionsDetectionTask,
 ) -> Result<ChatCompletionsResponse, Error> {
     let trace_id = task.trace_id;
-    let detectors = task.request.detectors.clone().unwrap_or_default();
+    let detectors = task.request.detectors.clone();
     info!(%trace_id, config = ?detectors, "task started");
     let input_detectors = detectors.input;
     let output_detectors = detectors.output;
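
Both handlers simplify identically: a request without a detectors key already falls back to DetectorConfig::default() inside try_from, so the field is a plain DetectorConfig rather than Option<DetectorConfig> and the unwrap_or_default() becomes redundant. A small sketch of that equivalence, with a stubbed config type (illustrative; the real DetectorConfig maps names to DetectorParams):

use std::collections::HashMap;

// Stubbed stand-in for the orchestrator's DetectorConfig.
#[derive(Debug, Default, PartialEq)]
struct DetectorConfig {
    input: HashMap<String, String>,
    output: HashMap<String, String>,
}

fn main() {
    // Old handler pattern: the field was Option<DetectorConfig>.
    let old = Option::<DetectorConfig>::None.unwrap_or_default();
    // New pattern: deserialization already produced a default config.
    let new = DetectorConfig::default();
    assert_eq!(old, new);
}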
