Skip to content

Commit 57fa04a

Browse files
committed
Cleaner version.
1 parent 1932c5b commit 57fa04a

File tree

4 files changed

+156
-77
lines changed

4 files changed

+156
-77
lines changed

Cargo.lock

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kvrouter/Cargo.toml

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ futures = "0.3.31"
1212
futures-util = "0.3.31"
1313
hyper = { version = "1.5.2", features = ["full"] }
1414
hyper-util = { version = "0.1.10", features = ["full"] }
15+
log = "0.4.25"
1516
rand = "0.9.0"
1617
slotmap = "1.0.7"
1718
tokio = { version = "1.43.0", features = ["macros", "rt-multi-thread"] }

kvrouter/src/lib.rs

+121-40
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,11 @@ use axum::{
55
response::{IntoResponse, Response},
66
};
77
use futures_util::stream::StreamExt;
8+
use hyper::StatusCode;
89
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
910
use rand::{rng, Rng};
1011
use std::sync::atomic::{AtomicUsize, Ordering};
11-
use std::sync::{Arc, Mutex};
12+
use tokio::sync::{mpsc, oneshot};
1213

1314
mod trie;
1415

@@ -17,9 +18,8 @@ use crate::trie::Trie;
1718
const FACTOR_KEY: &str = "TGI_KVROUTER_FACTOR";
1819
type Client = hyper_util::client::legacy::Client<HttpConnector, Body>;
1920

20-
#[derive(Clone)]
2121
pub struct ContentAware {
22-
trie: Arc<Mutex<Trie>>,
22+
trie: Trie,
2323
}
2424

2525
impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {
3030

3131
impl ContentAware {
3232
pub fn new() -> Self {
33-
let trie = Arc::new(Mutex::new(Trie::new()));
33+
let trie = Trie::new();
3434
Self { trie }
3535
}
3636
}
3737

3838
impl LoadBalancer for ContentAware {
3939
fn next(&mut self, key: &[u8], n_backends: usize) -> usize {
40-
let mut trie = self.trie.lock().unwrap();
40+
let trie = &mut self.trie;
4141
let (start, stop) = trie.insert(key);
4242
let n = trie.count();
4343
eprintln!(
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
6060
}
6161
}
6262

63-
#[derive(Clone)]
6463
pub struct RoundRobin {
65-
current: Arc<AtomicUsize>,
64+
current: AtomicUsize,
6665
}
6766

6867
impl Default for RoundRobin {
@@ -73,7 +72,7 @@ impl Default for RoundRobin {
7372

7473
impl RoundRobin {
7574
pub fn new() -> Self {
76-
let current = Arc::new(AtomicUsize::new(0));
75+
let current = AtomicUsize::new(0);
7776
Self { current }
7877
}
7978
}
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
8483
}
8584
}
8685

87-
#[derive(Clone)]
8886
pub struct OverloadHandler<T: LoadBalancer> {
89-
client: Client,
9087
load_balancer: T,
91-
backends: Arc<Vec<String>>,
92-
inqueue: Arc<Vec<AtomicUsize>>,
93-
inflight: Arc<Vec<AtomicUsize>>,
88+
backends: Vec<String>,
89+
inqueue: Vec<AtomicUsize>,
90+
inflight: Vec<AtomicUsize>,
9491
factor: f32,
92+
rx: Rcv,
9593
}
9694

9795
impl<T: LoadBalancer> OverloadHandler<T> {
98-
pub fn new(load_balancer: T, backends: Vec<String>) -> Self {
99-
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
100-
.build(HttpConnector::new());
101-
let inflight = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
102-
let inqueue = Arc::new(backends.iter().map(|_| AtomicUsize::new(0)).collect());
96+
pub fn new(load_balancer: T, backends: Vec<String>, rx: Rcv) -> Self {
97+
let inflight = backends.iter().map(|_| AtomicUsize::new(0)).collect();
98+
let inqueue = backends.iter().map(|_| AtomicUsize::new(0)).collect();
10399
let factor: f32 = std::env::var(FACTOR_KEY)
104100
.unwrap_or("1.5".to_string())
105101
.parse()
106102
.unwrap_or(1.5);
107-
let backends = Arc::new(backends);
108103
Self {
109104
load_balancer,
110105
backends,
111-
client,
112106
factor,
113107
inflight,
114108
inqueue,
109+
rx,
115110
}
116111
}
117112

118-
fn next(&mut self, key: &[u8]) -> usize {
113+
fn next(&mut self, key: &[u8]) -> String {
119114
// Get the backend URL
120115
let index = self.load_balancer.next(key, self.backends.len());
121116
let n = self.backends.len();
@@ -138,29 +133,117 @@ impl<T: LoadBalancer> OverloadHandler<T> {
138133
inflight = self.inflight[index].load(Ordering::Relaxed);
139134
inqueue = self.inflight[index].load(Ordering::Relaxed);
140135
}
141-
index
136+
let backend = &self.backends[index];
137+
self.inflight[index].fetch_add(1, Ordering::Relaxed);
138+
self.inqueue[index].fetch_add(1, Ordering::Relaxed);
139+
backend.to_string()
140+
}
141+
142+
pub async fn run(&mut self) {
143+
while let Some(msg) = self.rx.recv().await {
144+
eprintln!("Msg {msg:?}");
145+
match msg {
146+
Msg::Next(key, sx) => {
147+
let backend: String = self.next(&key);
148+
eprintln!("Sending back backend {backend}");
149+
if let Err(err) = sx.send(backend) {
150+
eprintln!("Cannot send back result: {err}");
151+
}
152+
}
153+
Msg::Dequeue(backend) => {
154+
let index = self.backends.iter().position(|b| b == &backend);
155+
if let Some(index) = index {
156+
self.inqueue[index].fetch_sub(1, Ordering::Relaxed);
157+
}
158+
}
159+
Msg::Deflight(backend) => {
160+
let index = self.backends.iter().position(|b| b == &backend);
161+
if let Some(index) = index {
162+
self.inflight[index].fetch_sub(1, Ordering::Relaxed);
163+
}
164+
}
165+
Msg::AddBackend(backend) => {
166+
self.backends.push(backend);
167+
self.backends.sort();
168+
}
169+
Msg::RemoveBackend(backend) => {
170+
self.backends.retain(|b| *b == backend);
171+
self.backends.sort();
172+
}
173+
}
174+
}
142175
}
143176
}
144177

145178
pub trait LoadBalancer {
146179
fn next(&mut self, key: &[u8], n_backends: usize) -> usize;
147180
}
148181

149-
pub async fn handler<T: LoadBalancer>(
150-
State(mut state): State<OverloadHandler<T>>,
182+
#[derive(Debug)]
183+
pub enum Msg {
184+
Next(Vec<u8>, oneshot::Sender<String>),
185+
Dequeue(String),
186+
Deflight(String),
187+
AddBackend(String),
188+
RemoveBackend(String),
189+
}
190+
191+
type Snd = mpsc::Sender<Msg>;
192+
type Rcv = mpsc::Receiver<Msg>;
193+
194+
#[derive(Clone)]
195+
pub struct Communicator {
196+
sender: Snd,
197+
client: Client,
198+
}
199+
200+
impl Communicator {
201+
pub fn new(sender: Snd) -> Self {
202+
let client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
203+
.build(HttpConnector::new());
204+
Self { sender, client }
205+
}
206+
207+
async fn dequeue(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
208+
self.sender.send(Msg::Dequeue(backend)).await
209+
}
210+
211+
async fn deflight(&self, backend: String) -> Result<(), mpsc::error::SendError<Msg>> {
212+
self.sender.send(Msg::Deflight(backend)).await
213+
}
214+
215+
async fn next(&self, key: Vec<u8>) -> Result<String, mpsc::error::SendError<Msg>> {
216+
let (sx, rx) = oneshot::channel();
217+
self.sender.send(Msg::Next(key, sx)).await?;
218+
let backend = rx.await.unwrap();
219+
Ok(backend)
220+
}
221+
}
222+
223+
pub async fn handler(
224+
State(state): State<Communicator>,
151225
req: Request,
152-
) -> Response<Body> {
226+
) -> Result<Response<Body>, StatusCode> {
153227
// Get the next backend index
154-
let limit = 1024 * 1024;
155228
let (parts, body) = req.into_parts();
156-
// TODO
157-
let bytes = axum::body::to_bytes(body, limit).await.unwrap();
158-
let index = state.next(&bytes);
159-
let backend = &state.backends[index];
160-
state.inflight[index].fetch_add(1, Ordering::Relaxed);
161-
state.inqueue[index].fetch_add(1, Ordering::Relaxed);
162-
163-
let body: Body = bytes.into();
229+
let mut response_stream = body.into_data_stream();
230+
let event = response_stream.next().await;
231+
let key = if let Some(Ok(event)) = &event {
232+
event.to_vec()
233+
} else {
234+
vec![]
235+
};
236+
let backend = state.next(key).await.map_err(|_| StatusCode::BAD_GATEWAY)?;
237+
let response_stream = async_stream::stream! {
238+
let mut response_stream = Box::pin(response_stream);
239+
if let Some(event) = event{
240+
yield event;
241+
}
242+
while let Some(raw_event) = response_stream.next().await {
243+
yield raw_event;
244+
}
245+
};
246+
let body = Body::from_stream(response_stream);
164247
let mut req = Request::from_parts(parts, body);
165248
let path = req.uri().path();
166249
let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
177260
.client
178261
.request(req)
179262
.await
180-
// TODO
181-
.unwrap();
182-
//.map_err(|_| StatusCode::BAD_GATEWAY)?;
263+
.map_err(|_| StatusCode::BAD_GATEWAY)?;
183264
let response = response.into_response();
184265
let (parts, body) = response.into_parts();
185266
let response_stream = body.into_data_stream();
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
190271
if start{
191272
eprintln!("Not inqueue");
192273

193-
state.inqueue[index].fetch_sub(1, Ordering::Relaxed);
274+
state.dequeue(backend.to_string()).await.unwrap();
194275
start = false;
195276
}
196277
yield raw_event;
197278
}
198279
eprintln!("Not inflight");
199-
state.inflight[index].fetch_sub(1, Ordering::Relaxed);
280+
state.deflight(backend.to_string()).await.unwrap();
200281
};
201282

202283
let body = Body::from_stream(response_stream);
203284

204-
Response::from_parts(parts, body)
285+
Ok(Response::from_parts(parts, body))
205286
}

kvrouter/src/main.rs

+31-35
Original file line number | Diff line number | Diff line change
@@ -2,50 +2,46 @@ use axum::{
22
routing::Router,
33
routing::{get, post},
44
};
5-
use kvrouter::{handler, ContentAware, OverloadHandler, RoundRobin};
5+
use kvrouter::{handler, Communicator, ContentAware, OverloadHandler, RoundRobin};
66

77
#[tokio::main]
88
async fn main() {
99
// List of backend servers
1010
let backends = vec![
1111
"http://localhost:8000".to_string(),
12-
"http://localhost:8001".to_string(),
13-
"http://localhost:8002".to_string(),
14-
"http://localhost:8003".to_string(),
12+
// "http://localhost:8001".to_string(),
13+
// "http://localhost:8002".to_string(),
14+
// "http://localhost:8003".to_string(),
1515
];
1616

1717
// Create a new instance of the RoundRobinRouter
18-
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
19-
println!("Using round robin");
20-
let lb = RoundRobin::new();
21-
// Create the Axum router
22-
let router = OverloadHandler::new(lb, backends);
23-
let app = Router::new()
24-
.route("/{*key}", get(handler))
25-
.route("/{*key}", post(handler))
26-
.with_state(router);
2718

28-
// run it
29-
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
30-
.await
31-
.unwrap();
32-
println!("listening on {}", listener.local_addr().unwrap());
33-
axum::serve(listener, app).await.unwrap();
34-
} else {
35-
println!("Using Content aware");
36-
let lb = ContentAware::new();
37-
// Create the Axum router
38-
let router = OverloadHandler::new(lb, backends);
39-
let app = Router::new()
40-
.route("/{*key}", get(handler))
41-
.route("/{*key}", post(handler))
42-
.with_state(router);
19+
println!("Using Content aware");
20+
// Create the Axum router
4321

44-
// run it
45-
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
46-
.await
47-
.unwrap();
48-
println!("listening on {}", listener.local_addr().unwrap());
49-
axum::serve(listener, app).await.unwrap();
50-
};
22+
let (sx, rx) = tokio::sync::mpsc::channel(100);
23+
let communicator = Communicator::new(sx);
24+
tokio::task::spawn(async move {
25+
if std::env::var("TGI_KVROUTER_LB").unwrap_or("".to_string()) == *"roundrobin" {
26+
println!("Using round robin");
27+
let lb = RoundRobin::new();
28+
let mut router = OverloadHandler::new(lb, backends, rx);
29+
router.run().await;
30+
} else {
31+
let lb = ContentAware::new();
32+
let mut router = OverloadHandler::new(lb, backends, rx);
33+
router.run().await;
34+
};
35+
});
36+
let app = Router::new()
37+
.route("/{*key}", get(handler))
38+
.route("/{*key}", post(handler))
39+
.with_state(communicator);
40+
41+
// run it
42+
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
43+
.await
44+
.unwrap();
45+
println!("listening on {}", listener.local_addr().unwrap());
46+
axum::serve(listener, app).await.unwrap();
5147
}

0 commit comments

Comments
 (0)