@@ -5,10 +5,11 @@ use axum::{
5
5
response:: { IntoResponse , Response } ,
6
6
} ;
7
7
use futures_util:: stream:: StreamExt ;
8
+ use hyper:: StatusCode ;
8
9
use hyper_util:: { client:: legacy:: connect:: HttpConnector , rt:: TokioExecutor } ;
9
10
use rand:: { rng, Rng } ;
10
11
use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
11
- use std :: sync:: { Arc , Mutex } ;
12
+ use tokio :: sync:: { mpsc , oneshot } ;
12
13
13
14
mod trie;
14
15
@@ -17,9 +18,8 @@ use crate::trie::Trie;
17
18
const FACTOR_KEY : & str = "TGI_KVROUTER_FACTOR" ;
18
19
type Client = hyper_util:: client:: legacy:: Client < HttpConnector , Body > ;
19
20
20
- #[ derive( Clone ) ]
21
21
pub struct ContentAware {
22
- trie : Arc < Mutex < Trie > > ,
22
+ trie : Trie ,
23
23
}
24
24
25
25
impl Default for ContentAware {
@@ -30,14 +30,14 @@ impl Default for ContentAware {
30
30
31
31
impl ContentAware {
32
32
pub fn new ( ) -> Self {
33
- let trie = Arc :: new ( Mutex :: new ( Trie :: new ( ) ) ) ;
33
+ let trie = Trie :: new ( ) ;
34
34
Self { trie }
35
35
}
36
36
}
37
37
38
38
impl LoadBalancer for ContentAware {
39
39
fn next ( & mut self , key : & [ u8 ] , n_backends : usize ) -> usize {
40
- let mut trie = self . trie . lock ( ) . unwrap ( ) ;
40
+ let trie = & mut self . trie ;
41
41
let ( start, stop) = trie. insert ( key) ;
42
42
let n = trie. count ( ) ;
43
43
eprintln ! (
@@ -60,9 +60,8 @@ impl LoadBalancer for ContentAware {
60
60
}
61
61
}
62
62
63
- #[ derive( Clone ) ]
64
63
pub struct RoundRobin {
65
- current : Arc < AtomicUsize > ,
64
+ current : AtomicUsize ,
66
65
}
67
66
68
67
impl Default for RoundRobin {
@@ -73,7 +72,7 @@ impl Default for RoundRobin {
73
72
74
73
impl RoundRobin {
75
74
pub fn new ( ) -> Self {
76
- let current = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
75
+ let current = AtomicUsize :: new ( 0 ) ;
77
76
Self { current }
78
77
}
79
78
}
@@ -84,38 +83,34 @@ impl LoadBalancer for RoundRobin {
84
83
}
85
84
}
86
85
87
- #[ derive( Clone ) ]
88
86
pub struct OverloadHandler < T : LoadBalancer > {
89
- client : Client ,
90
87
load_balancer : T ,
91
- backends : Arc < Vec < String > > ,
92
- inqueue : Arc < Vec < AtomicUsize > > ,
93
- inflight : Arc < Vec < AtomicUsize > > ,
88
+ backends : Vec < String > ,
89
+ inqueue : Vec < AtomicUsize > ,
90
+ inflight : Vec < AtomicUsize > ,
94
91
factor : f32 ,
92
+ rx : Rcv ,
95
93
}
96
94
97
95
impl < T : LoadBalancer > OverloadHandler < T > {
98
- pub fn new ( load_balancer : T , backends : Vec < String > ) -> Self {
99
- let client = hyper_util:: client:: legacy:: Client :: < ( ) , ( ) > :: builder ( TokioExecutor :: new ( ) )
100
- . build ( HttpConnector :: new ( ) ) ;
101
- let inflight = Arc :: new ( backends. iter ( ) . map ( |_| AtomicUsize :: new ( 0 ) ) . collect ( ) ) ;
102
- let inqueue = Arc :: new ( backends. iter ( ) . map ( |_| AtomicUsize :: new ( 0 ) ) . collect ( ) ) ;
96
+ pub fn new ( load_balancer : T , backends : Vec < String > , rx : Rcv ) -> Self {
97
+ let inflight = backends. iter ( ) . map ( |_| AtomicUsize :: new ( 0 ) ) . collect ( ) ;
98
+ let inqueue = backends. iter ( ) . map ( |_| AtomicUsize :: new ( 0 ) ) . collect ( ) ;
103
99
let factor: f32 = std:: env:: var ( FACTOR_KEY )
104
100
. unwrap_or ( "1.5" . to_string ( ) )
105
101
. parse ( )
106
102
. unwrap_or ( 1.5 ) ;
107
- let backends = Arc :: new ( backends) ;
108
103
Self {
109
104
load_balancer,
110
105
backends,
111
- client,
112
106
factor,
113
107
inflight,
114
108
inqueue,
109
+ rx,
115
110
}
116
111
}
117
112
118
- fn next ( & mut self , key : & [ u8 ] ) -> usize {
113
+ fn next ( & mut self , key : & [ u8 ] ) -> String {
119
114
// Get the backend URL
120
115
let index = self . load_balancer . next ( key, self . backends . len ( ) ) ;
121
116
let n = self . backends . len ( ) ;
@@ -138,29 +133,117 @@ impl<T: LoadBalancer> OverloadHandler<T> {
138
133
inflight = self . inflight [ index] . load ( Ordering :: Relaxed ) ;
139
134
inqueue = self . inflight [ index] . load ( Ordering :: Relaxed ) ;
140
135
}
141
- index
136
+ let backend = & self . backends [ index] ;
137
+ self . inflight [ index] . fetch_add ( 1 , Ordering :: Relaxed ) ;
138
+ self . inqueue [ index] . fetch_add ( 1 , Ordering :: Relaxed ) ;
139
+ backend. to_string ( )
140
+ }
141
+
142
+ pub async fn run ( & mut self ) {
143
+ while let Some ( msg) = self . rx . recv ( ) . await {
144
+ eprintln ! ( "Msg {msg:?}" ) ;
145
+ match msg {
146
+ Msg :: Next ( key, sx) => {
147
+ let backend: String = self . next ( & key) ;
148
+ eprintln ! ( "Sending back backend {backend}" ) ;
149
+ if let Err ( err) = sx. send ( backend) {
150
+ eprintln ! ( "Cannot send back result: {err}" ) ;
151
+ }
152
+ }
153
+ Msg :: Dequeue ( backend) => {
154
+ let index = self . backends . iter ( ) . position ( |b| b == & backend) ;
155
+ if let Some ( index) = index {
156
+ self . inqueue [ index] . fetch_sub ( 1 , Ordering :: Relaxed ) ;
157
+ }
158
+ }
159
+ Msg :: Deflight ( backend) => {
160
+ let index = self . backends . iter ( ) . position ( |b| b == & backend) ;
161
+ if let Some ( index) = index {
162
+ self . inflight [ index] . fetch_sub ( 1 , Ordering :: Relaxed ) ;
163
+ }
164
+ }
165
+ Msg :: AddBackend ( backend) => {
166
+ self . backends . push ( backend) ;
167
+ self . backends . sort ( ) ;
168
+ }
169
+ Msg :: RemoveBackend ( backend) => {
170
+ self . backends . retain ( |b| * b == backend) ;
171
+ self . backends . sort ( ) ;
172
+ }
173
+ }
174
+ }
142
175
}
143
176
}
144
177
145
178
pub trait LoadBalancer {
146
179
fn next ( & mut self , key : & [ u8 ] , n_backends : usize ) -> usize ;
147
180
}
148
181
149
- pub async fn handler < T : LoadBalancer > (
150
- State ( mut state) : State < OverloadHandler < T > > ,
182
+ #[ derive( Debug ) ]
183
+ pub enum Msg {
184
+ Next ( Vec < u8 > , oneshot:: Sender < String > ) ,
185
+ Dequeue ( String ) ,
186
+ Deflight ( String ) ,
187
+ AddBackend ( String ) ,
188
+ RemoveBackend ( String ) ,
189
+ }
190
+
191
+ type Snd = mpsc:: Sender < Msg > ;
192
+ type Rcv = mpsc:: Receiver < Msg > ;
193
+
194
+ #[ derive( Clone ) ]
195
+ pub struct Communicator {
196
+ sender : Snd ,
197
+ client : Client ,
198
+ }
199
+
200
+ impl Communicator {
201
+ pub fn new ( sender : Snd ) -> Self {
202
+ let client = hyper_util:: client:: legacy:: Client :: < ( ) , ( ) > :: builder ( TokioExecutor :: new ( ) )
203
+ . build ( HttpConnector :: new ( ) ) ;
204
+ Self { sender, client }
205
+ }
206
+
207
+ async fn dequeue ( & self , backend : String ) -> Result < ( ) , mpsc:: error:: SendError < Msg > > {
208
+ self . sender . send ( Msg :: Dequeue ( backend) ) . await
209
+ }
210
+
211
+ async fn deflight ( & self , backend : String ) -> Result < ( ) , mpsc:: error:: SendError < Msg > > {
212
+ self . sender . send ( Msg :: Deflight ( backend) ) . await
213
+ }
214
+
215
+ async fn next ( & self , key : Vec < u8 > ) -> Result < String , mpsc:: error:: SendError < Msg > > {
216
+ let ( sx, rx) = oneshot:: channel ( ) ;
217
+ self . sender . send ( Msg :: Next ( key, sx) ) . await ?;
218
+ let backend = rx. await . unwrap ( ) ;
219
+ Ok ( backend)
220
+ }
221
+ }
222
+
223
+ pub async fn handler (
224
+ State ( state) : State < Communicator > ,
151
225
req : Request ,
152
- ) -> Response < Body > {
226
+ ) -> Result < Response < Body > , StatusCode > {
153
227
// Get the next backend index
154
- let limit = 1024 * 1024 ;
155
228
let ( parts, body) = req. into_parts ( ) ;
156
- // TODO
157
- let bytes = axum:: body:: to_bytes ( body, limit) . await . unwrap ( ) ;
158
- let index = state. next ( & bytes) ;
159
- let backend = & state. backends [ index] ;
160
- state. inflight [ index] . fetch_add ( 1 , Ordering :: Relaxed ) ;
161
- state. inqueue [ index] . fetch_add ( 1 , Ordering :: Relaxed ) ;
162
-
163
- let body: Body = bytes. into ( ) ;
229
+ let mut response_stream = body. into_data_stream ( ) ;
230
+ let event = response_stream. next ( ) . await ;
231
+ let key = if let Some ( Ok ( event) ) = & event {
232
+ event. to_vec ( )
233
+ } else {
234
+ vec ! [ ]
235
+ } ;
236
+ let backend = state. next ( key) . await . map_err ( |_| StatusCode :: BAD_GATEWAY ) ?;
237
+ let response_stream = async_stream:: stream! {
238
+ let mut response_stream = Box :: pin( response_stream) ;
239
+ if let Some ( event) = event{
240
+ yield event;
241
+ }
242
+ while let Some ( raw_event) = response_stream. next( ) . await {
243
+ yield raw_event;
244
+ }
245
+ } ;
246
+ let body = Body :: from_stream ( response_stream) ;
164
247
let mut req = Request :: from_parts ( parts, body) ;
165
248
let path = req. uri ( ) . path ( ) ;
166
249
let path_query = req
@@ -177,9 +260,7 @@ pub async fn handler<T: LoadBalancer>(
177
260
. client
178
261
. request ( req)
179
262
. await
180
- // TODO
181
- . unwrap ( ) ;
182
- //.map_err(|_| StatusCode::BAD_GATEWAY)?;
263
+ . map_err ( |_| StatusCode :: BAD_GATEWAY ) ?;
183
264
let response = response. into_response ( ) ;
184
265
let ( parts, body) = response. into_parts ( ) ;
185
266
let response_stream = body. into_data_stream ( ) ;
@@ -190,16 +271,16 @@ pub async fn handler<T: LoadBalancer>(
190
271
if start{
191
272
eprintln!( "Not inqueue" ) ;
192
273
193
- state. inqueue [ index ] . fetch_sub ( 1 , Ordering :: Relaxed ) ;
274
+ state. dequeue ( backend . to_string ( ) ) . await . unwrap ( ) ;
194
275
start = false ;
195
276
}
196
277
yield raw_event;
197
278
}
198
279
eprintln!( "Not inflight" ) ;
199
- state. inflight [ index ] . fetch_sub ( 1 , Ordering :: Relaxed ) ;
280
+ state. deflight ( backend . to_string ( ) ) . await . unwrap ( ) ;
200
281
} ;
201
282
202
283
let body = Body :: from_stream ( response_stream) ;
203
284
204
- Response :: from_parts ( parts, body)
285
+ Ok ( Response :: from_parts ( parts, body) )
205
286
}
0 commit comments