diff --git a/Cargo.lock b/Cargo.lock index bc41907f..1dcefbc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,6 +524,26 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.8.0" @@ -660,6 +680,7 @@ dependencies = [ "axum-test", "bytes", "clap", + "dashmap", "eventsource-stream", "futures", "futures-util", @@ -909,6 +930,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.2" @@ -1382,7 +1409,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 27d62511..216bb559 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ axum = { version = "0.8.1", features = ["json"] } axum-extra = { version = "0.10.0", features = ["json-lines"] } bytes = "1.10.0" clap = { version = "4.5.26", features = ["derive", 
"env"] } +dashmap = "6.1.0" eventsource-stream = "0.2.3" futures = "0.3.31" futures-util = { version = "0.3", default-features = false, features = [] } diff --git a/src/clients/openai.rs b/src/clients/openai.rs index 19050e56..3362cc05 100644 --- a/src/clients/openai.rs +++ b/src/clients/openai.rs @@ -644,7 +644,7 @@ pub struct ChatCompletionChunk { } /// Streaming chat completion chunk choice. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ChatCompletionChunkChoice { /// The index of the choice in the list of choices. pub index: u32, @@ -659,7 +659,7 @@ pub struct ChatCompletionChunkChoice { } /// Streaming chat completion delta. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Default, Clone, Serialize, Deserialize)] pub struct ChatCompletionDelta { /// The role of the author of this message. #[serde(skip_serializing_if = "Option::is_none")] diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs index 74ebcb59..2bd0f1a7 100644 --- a/src/orchestrator/handlers/chat_completions_detection/streaming.rs +++ b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -14,26 +14,32 @@ limitations under the License. 
*/ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; +use dashmap::DashMap; +use futures::StreamExt; +use opentelemetry::trace::TraceId; use tokio::sync::mpsc; -use tracing::{Instrument, info}; +use tracing::{Instrument, error, info, instrument}; +use uuid::Uuid; use super::ChatCompletionsDetectionTask; use crate::{ clients::openai::*, - orchestrator::{Context, Error}, + config::DetectorType, + models::{DetectionWarningReason, DetectorParams, UNSUITABLE_INPUT_MESSAGE}, + orchestrator::{ + common::{self, validate_detectors}, types::{ChatCompletionBatcher, ChatCompletionStream, ChatMessageIterator, ChoiceIndex, Chunk, DetectionBatchStream, DetectionStream, Detections}, Context, Error + }, }; pub async fn handle_streaming( - _ctx: Arc, + ctx: Arc, task: ChatCompletionsDetectionTask, ) -> Result { let trace_id = task.trace_id; let detectors = task.request.detectors.clone(); info!(%trace_id, config = ?detectors, "task started"); - let _input_detectors = detectors.input; - let _output_detectors = detectors.output; // Create response channel let (response_tx, response_rx) = @@ -41,15 +47,436 @@ pub async fn handle_streaming( tokio::spawn( async move { - // TODO - let _ = response_tx - .send(Err(Error::Validation( - "streaming is not yet supported".into(), - ))) + let input_detectors = detectors.input; + let output_detectors = detectors.output; + + // Validate input detectors + if let Err(error) = validate_detectors( + &input_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + ) { + let _ = response_tx.send(Err(error)).await; + return; + } + // Validate output detectors + if let Err(error) = validate_detectors( + &output_detectors, + &ctx.config.detectors, + &[DetectorType::TextContents], + true, + ) { + let _ = response_tx.send(Err(error)).await; + return; + } + + // Handle input detection (unary) + if !input_detectors.is_empty() { + match handle_input_detection(ctx.clone(), &task, input_detectors).await { + Ok(Some(chunk)) => { 
+ info!(%trace_id, "task completed: returning response with input detections"); + // Send message with input detections to response channel and terminate + let _ = response_tx.send(Ok(Some(chunk))).await; + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; + return; + } + Ok(None) => (), // No input detections + Err(error) => { + // Input detections failed + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + + // Create chat completions stream + let client = ctx + .clients + .get_as::("chat_generation") + .unwrap(); + let chat_completion_stream = + match common::chat_completion_stream(client, task.headers.clone(), task.request.clone()).await { + Ok(stream) => stream, + Err(error) => { + error!(%trace_id, %error, "task failed: error creating chat completions stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + }; + + if output_detectors.is_empty() { + // No output detectors, forward chat completion chunks to response channel + forward_chat_completion_stream(trace_id, chat_completion_stream, response_tx.clone()).await; + } else { + // Partition output detectors + // Detectors using whole_doc_chunker are processed at the end after all chat completion chunks + // have been collected. Results are returned with the second-last message. 
+ let (whole_doc_output_detectors, output_detectors): (HashMap<_, _>, HashMap<_, _>) = output_detectors + .into_iter() + .partition(|(detector_id, _params)| { + let chunker_id = ctx + .config + .get_chunker_id(detector_id) + .unwrap(); + chunker_id == "whole_doc_chunker" + }); + + // Create chat completions state + // This holds all chat completion chunks received and is used to build responses + let chat_completion_state: Arc>> = Arc::new(DashMap::new()); + + // Handle output detection + handle_output_detection( + ctx.clone(), + &task, + output_detectors, + chat_completion_state.clone(), + chat_completion_stream, + response_tx.clone(), + ) + .await; + + // Handle whole doc output detection + handle_whole_doc_output_detection( + ctx.clone(), + &task, + whole_doc_output_detectors, + chat_completion_state, + response_tx.clone(), + ) .await; + } + + // Send None to signal completion + let _ = response_tx.send(Ok(None)).await; } .in_current_span(), ); Ok(ChatCompletionsResponse::Streaming(response_rx)) } + +#[instrument(skip_all)] +async fn handle_input_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, +) -> Result, Error> { + let trace_id = task.trace_id; + let model_id = task.request.model.clone(); + + // Input detectors are only applied to the last message + // Get the last message + let messages = task.request.messages(); + let message = if let Some(message) = messages.last() { + message + } else { + return Err(Error::Validation("No messages provided".into())); + }; + // Validate role + if !matches!( + message.role, + Some(Role::User) | Some(Role::Assistant) | Some(Role::System) + ) { + return Err(Error::Validation( + "Last message role must be user, assistant, or system".into(), + )); + } + let input_id = message.index; + let input_text = message.text.map(|s| s.to_string()).unwrap_or_default(); + let detections = match common::text_contents_detections( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + input_id, + 
vec![(0, input_text)], + ) + .await + { + Ok((_, detections)) => detections, + Err(error) => { + error!(%trace_id, %error, "task failed: error processing input detections"); + return Err(error); + } + }; + if !detections.is_empty() { + // Build chat completion chunk with input detections + let chunk = ChatCompletionChunk { + id: Uuid::new_v4().simple().to_string(), + model: model_id, + created: common::current_timestamp().as_secs() as i64, + detections: Some(ChatDetections { + input: vec![InputDetectionResult { + message_index: message.index, + results: detections.into(), + }], + ..Default::default() + }), + warnings: vec![OrchestratorWarning::new( + DetectionWarningReason::UnsuitableInput, + UNSUITABLE_INPUT_MESSAGE, + )], + ..Default::default() + }; + Ok(Some(chunk)) + } else { + // No input detections + Ok(None) + } +} + + +#[instrument(skip_all)] +async fn handle_output_detection( + ctx: Arc, + task: &ChatCompletionsDetectionTask, + detectors: HashMap, + chat_completion_state: Arc>>, + chat_completion_stream: ChatCompletionStream, + response_tx: mpsc::Sender, Error>>, +) { + let trace_id = &task.trace_id; + let request = task.request.clone(); + // n represents how many choices to generate for each input message (default=1) + let n = request.extra.get("n").and_then(|v| v.as_i64()).unwrap_or(1) as usize; + + // Create input channels + // As choices are processed independently, each choice_index has it's own input channels. + let mut input_txs = HashMap::with_capacity(n); + let mut input_rxs = HashMap::with_capacity(n); + (0..n).for_each(|choice_index| { + let (input_tx, input_rx) = mpsc::channel::>(32); + input_txs.insert(choice_index as u32, input_tx); + input_rxs.insert(choice_index as u32, input_rx); + }); + + // Create detection streams + // As choices are processed independently, each choice_index has it's own detection streams. 
+ let mut detection_streams = Vec::with_capacity(n * detectors.len()); + for (choice_index, input_rx) in input_rxs { + match common::text_contents_detection_streams( + ctx.clone(), + task.headers.clone(), + detectors.clone(), + choice_index, + input_rx, + ) + .await + { + Ok(streams) => { + detection_streams.extend(streams); + } + Err(error) => { + error!(%trace_id, %error, "task failed: error creating detection streams"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; return; + } + } + } + + // Spawn task to consume chat completions stream, send text to input channels, and update chat completion state + tokio::spawn(process_chat_completion_stream( + chat_completion_state.clone(), + chat_completion_stream, + input_txs, + )); + + // Process detection streams + if detection_streams.len() == 1 { + // Process single detection stream, batching not applicable + let detection_stream = detection_streams.swap_remove(0); + process_detection_stream(trace_id, chat_completion_state, detection_stream, response_tx).await; + } else { + // Create detection batch stream + let detection_batch_stream = + DetectionBatchStream::new(ChatCompletionBatcher::new(detectors.len()), detection_streams); + process_detection_batch_stream( + trace_id, + chat_completion_state, + detection_batch_stream, + response_tx, + ) + .await; + } +} + +#[instrument(skip_all)] +async fn handle_whole_doc_output_detection( + _ctx: Arc, + _task: &ChatCompletionsDetectionTask, + _detectors: HashMap, + _chat_completion_state: Arc>>, + _response_tx: mpsc::Sender, Error>>, +) { + todo!() +} + +async fn forward_chat_completion_stream( + trace_id: TraceId, + mut chat_completion_stream: ChatCompletionStream, + response_tx: mpsc::Sender, Error>>, +) { + while let Some((_index, result)) = chat_completion_stream.next().await { + match result { + Ok(Some(chat_completion)) => { + // Send message to response channel + if response_tx.send(Ok(Some(chat_completion))).await.is_err() {
info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Ok(None) => { + // Send message to response channel + if response_tx.send(Ok(None)).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from chat completion stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: chat completion stream closed"); +} + +/// Consumes chat completion stream, sends choices to input channels, +/// and updates shared chat completions. +#[allow(clippy::type_complexity)] +async fn process_chat_completion_stream( + chat_completion_state: Arc>>, + mut chat_completion_stream: ChatCompletionStream, + input_txs: HashMap>>, +) { + while let Some((message_index, result)) = chat_completion_stream.next().await { + match result { + Ok(Some(chat_completion)) => { + // Send choice text to input channel + let choice = &chat_completion.choices[0]; // TODO: handle + let choice_index = choice.index; + let choice_text = choice.delta.content.clone().unwrap_or_default(); + let input_tx = input_txs.get(&choice_index).unwrap(); + let _ = input_tx.send(Ok((message_index, choice_text))).await; + + // Update chat completion state + match chat_completion_state.entry(choice_index) { + dashmap::Entry::Occupied(mut entry) => { + entry.get_mut().push(chat_completion); + }, + dashmap::Entry::Vacant(entry) => { + entry.insert(vec![chat_completion]); + }, + } + } + Ok(None) => (), + Err(error) => { + // Send error to all input channels + for input_tx in input_txs.values() { + let _ = input_tx.send(Err(error.clone())).await; + } + } + } + } +} + +/// Builds a response with output detections. 
+fn output_detection_response( + chat_completion_state: &Arc>>, + choice_index: u32, + chunk: Chunk, + detections: Detections, +) -> Result { + // Get chat completions for this choice index + let chat_completions = chat_completion_state + .get(&choice_index) + .unwrap(); + // let chat_completions = chat_completions + // .get(chunk.input_start_index..=chunk.input_end_index) + // .unwrap_or_default(); + // let mut chat_completion = chat_completions.last().cloned().unwrap(); + let mut chat_completion = chat_completions + .get(chunk.input_end_index) + .cloned() + .unwrap_or_default(); + chat_completion.choices = vec![ChatCompletionChunkChoice { + index: choice_index, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + content: Some(chunk.text), + ..Default::default() + }, + ..Default::default() + }]; + chat_completion.detections = Some(ChatDetections { + output: vec![OutputDetectionResult { + choice_index, + results: detections.into(), + }], + ..Default::default() + }); + //chat_completion.warnings = todo!(); + Ok(chat_completion) +} + +/// Consumes a detection stream, builds responses, and sends them to a response channel. 
+async fn process_detection_stream( + trace_id: &TraceId, + chat_completion_state: Arc>>, + mut detection_stream: DetectionStream, + response_tx: mpsc::Sender, Error>>, +) { + while let Some(result) = detection_stream.next().await { + match result { + Ok((choice_index, _detector_id, chunk, detections)) => { + let chat_completion = output_detection_response(&chat_completion_state, choice_index, chunk, detections).unwrap(); + // Send message to response channel + if response_tx.send(Ok(Some(chat_completion))).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection stream closed"); +} + +/// Consumes a detection batch stream, builds responses, and sends them to a response channel. +async fn process_detection_batch_stream( + trace_id: &TraceId, + chat_completion_state: Arc>>, + mut detection_batch_stream: DetectionBatchStream, + response_tx: mpsc::Sender, Error>>, +) { + while let Some(result) = detection_batch_stream.next().await { + match result { + Ok((chunk, choice_index, detections)) => { + let chat_completion = output_detection_response(&chat_completion_state, choice_index, chunk, detections).unwrap(); + // Send message to response channel + if response_tx.send(Ok(Some(chat_completion))).await.is_err() { + info!(%trace_id, "task completed: client disconnected"); + return; + } + } + Err(error) => { + error!(%trace_id, %error, "task failed: error received from detection batch stream"); + // Send error to response channel and terminate + let _ = response_tx.send(Err(error)).await; + return; + } + } + } + info!(%trace_id, "task completed: detection batch stream closed"); +} \ No newline at end of file