Skip to content

Commit 4769102

Browse files
committed
add support for WRR algo for sampling
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 7013f49 commit 4769102

7 files changed

Lines changed: 158 additions & 46 deletions

File tree

src/client/client.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use std::collections::HashMap;
22

3-
use crate::config::{Config, ModelConfig, ModelId, RoutingMode};
3+
use crate::config::{Config, ModelId};
44
use crate::provider::provider;
55
use crate::router::router;
66

77
pub struct Client {
88
providers: HashMap<ModelId, Box<dyn provider::Provider>>,
99
router: Box<dyn router::Router>,
10-
router_stats: router::RouterStats,
1110
}
1211

1312
impl Client {
@@ -24,12 +23,11 @@ impl Client {
2423
Self {
2524
providers: providers,
2625
router: router::construct_router(cfg.routing_mode, cfg.models),
27-
router_stats: router::RouterStats::default(),
2826
}
2927
}
3028

3129
pub async fn create_response(
32-
&self,
30+
&mut self,
3331
request: provider::ResponseRequest,
3432
) -> Result<provider::ResponseResult, provider::APIError> {
3533
let model_id = self.router.sample(&request);
@@ -41,6 +39,8 @@ impl Client {
4139
#[cfg(test)]
4240
mod tests {
4341
use super::*;
42+
use crate::config::{Config, ModelConfig, RoutingMode};
43+
4444
#[test]
4545
fn test_client_new() {
4646
struct TestCase {
@@ -64,9 +64,9 @@ mod tests {
6464
expected_router_name: "RandomRouter",
6565
},
6666
TestCase {
67-
name: "weighted router",
67+
name: "weighted round-robin router",
6868
config: Config::builder()
69-
.routing_mode(RoutingMode::Weighted)
69+
.routing_mode(RoutingMode::WRR)
7070
.models(vec![
7171
crate::config::ModelConfig::builder()
7272
.id("model_a".to_string())

src/config.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ lazy_static! {
2323
#[derive(Debug, Clone, PartialEq)]
2424
pub enum RoutingMode {
2525
Random,
26-
Weighted,
26+
WRR, // WeightedRoundRobin,
2727
}
2828

2929
// ------------------ Model Config ------------------
@@ -131,7 +131,7 @@ impl ConfigBuilder {
131131

132132
for model in self.models.as_ref().unwrap() {
133133
if self.routing_mode.is_some()
134-
&& self.routing_mode.as_ref().unwrap() == &RoutingMode::Weighted
134+
&& self.routing_mode.as_ref().unwrap() == &RoutingMode::WRR
135135
&& model.weight <= 0
136136
{
137137
return Err(format!(

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod router {
22
mod random;
33
pub mod router;
4+
pub mod stats;
45
mod weight;
56
}
67
mod client {

src/router/random.rs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ use rand::Rng;
22

33
use crate::config::ModelId;
44
use crate::provider::provider::ResponseRequest;
5-
use crate::router::router::Router;
5+
use crate::router::router::{ModelInfo, Router};
66

77
pub struct RandomRouter {
8-
pub model_ids: Vec<ModelId>,
8+
pub model_infos: Vec<ModelInfo>,
99
}
1010

1111
impl RandomRouter {
12-
pub fn new(model_ids: Vec<ModelId>) -> Self {
13-
Self { model_ids }
12+
pub fn new(model_infos: Vec<ModelInfo>) -> Self {
13+
Self { model_infos }
1414
}
1515
}
1616

@@ -19,10 +19,10 @@ impl Router for RandomRouter {
1919
"RandomRouter"
2020
}
2121

22-
fn sample(&self, _input: &ResponseRequest) -> ModelId {
22+
fn sample(&mut self, _input: &ResponseRequest) -> ModelId {
2323
let mut rng = rand::rng();
24-
let idx = rng.random_range(0..self.model_ids.len());
25-
self.model_ids[idx].clone()
24+
let idx = rng.random_range(0..self.model_infos.len());
25+
self.model_infos[idx].id.clone()
2626
}
2727
}
2828

@@ -32,14 +32,28 @@ mod tests {
3232

3333
#[test]
3434
fn test_random_router_sampling() {
35-
let model_ids = vec!["model_a".to_string(), "model_b".to_string()];
36-
let router = RandomRouter::new(model_ids.clone());
35+
let model_infos = vec![
36+
ModelInfo {
37+
id: "model_x".to_string(),
38+
weight: 1,
39+
},
40+
ModelInfo {
41+
id: "model_y".to_string(),
42+
weight: 2,
43+
},
44+
ModelInfo {
45+
id: "model_z".to_string(),
46+
weight: 3,
47+
},
48+
];
49+
let mut router = RandomRouter::new(model_infos.clone());
3750
let mut counts = std::collections::HashMap::new();
51+
3852
for _ in 0..1000 {
3953
let sampled_id = router.sample(&ResponseRequest::default());
4054
*counts.entry(sampled_id.clone()).or_insert(0) += 1;
4155
}
42-
assert!(counts.len() == model_ids.len());
56+
assert!(counts.len() == model_infos.len());
4357
for count in counts.values() {
4458
assert!(*count > 0);
4559
}

src/router/router.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,31 @@ use std::sync::atomic::AtomicUsize;
44
use crate::config::{ModelConfig, ModelId, RoutingMode};
55
use crate::provider::provider::ResponseRequest;
66
use crate::router::random::RandomRouter;
7-
use crate::router::weight::WeightedRouter;
7+
use crate::router::weight::WeightedRoundRobinRouter;
8+
9+
#[derive(Debug, Clone)]
10+
pub struct ModelInfo {
11+
pub id: ModelId,
12+
pub weight: i32,
13+
}
814

915
pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
10-
let model_ids: Vec<ModelId> = models.iter().map(|m| m.id.clone()).collect();
16+
let model_infos: Vec<ModelInfo> = models
17+
.iter()
18+
.map(|m| ModelInfo {
19+
id: m.id.clone(),
20+
weight: m.weight.clone(),
21+
})
22+
.collect();
1123
match mode {
12-
RoutingMode::Random => Box::new(RandomRouter::new(model_ids)),
13-
RoutingMode::Weighted => Box::new(WeightedRouter::new(model_ids)),
24+
RoutingMode::Random => Box::new(RandomRouter::new(model_infos)),
25+
RoutingMode::WRR => Box::new(WeightedRoundRobinRouter::new(model_infos)),
1426
}
1527
}
1628

1729
pub trait Router {
1830
fn name(&self) -> &'static str;
19-
fn sample(&self, input: &ResponseRequest) -> ModelId;
20-
}
21-
22-
pub struct RouterStats {
23-
total_requests: HashMap<ModelId, AtomicUsize>,
24-
}
25-
26-
impl RouterStats {
27-
pub fn default() -> Self {
28-
RouterStats {
29-
total_requests: HashMap::new(),
30-
}
31-
}
31+
fn sample(&mut self, input: &ResponseRequest) -> ModelId;
3232
}
3333

3434
#[cfg(test)]
@@ -54,7 +54,7 @@ mod tests {
5454
let random_router = construct_router(RoutingMode::Random, model_configs.clone());
5555
assert_eq!(random_router.name(), "RandomRouter");
5656

57-
let weighted_router = construct_router(RoutingMode::Weighted, model_configs.clone());
57+
let weighted_router = construct_router(RoutingMode::WRR, model_configs.clone());
5858
assert_eq!(weighted_router.name(), "WeightedRouter");
5959
}
6060
}

src/router/stats.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use std::collections::HashMap;
2+
use std::sync::atomic::{AtomicUsize, Ordering};
3+
4+
use crate::config::ModelId;
5+
6+
pub struct RouterStats {
7+
requests_per_model: HashMap<ModelId, AtomicUsize>,
8+
}
9+
10+
impl RouterStats {
11+
pub fn default() -> Self {
12+
RouterStats {
13+
requests_per_model: HashMap::new(),
14+
}
15+
}
16+
17+
pub fn increment_request(&mut self, model_id: &ModelId) -> usize {
18+
let counter = self
19+
.requests_per_model
20+
.entry(model_id.clone())
21+
.or_insert_with(|| AtomicUsize::new(0));
22+
counter.fetch_add(1, Ordering::Relaxed)
23+
}
24+
}

src/router/weight.rs

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,96 @@
1-
use super::router::Router;
1+
use crate::router::router::{ModelInfo, Router};
22
use crate::{config::ModelId, provider::provider::ResponseRequest};
33

4-
pub struct WeightedRouter {
5-
pub model_ids: Vec<ModelId>,
4+
pub struct WeightedRoundRobinRouter {
5+
total_weight: i32,
6+
model_infos: Vec<ModelInfo>,
7+
// current_weight is ordered by model_infos index.
8+
current_weights: Vec<i32>,
69
}
710

8-
impl WeightedRouter {
9-
pub fn new(model_ids: Vec<ModelId>) -> Self {
10-
Self { model_ids }
11+
impl WeightedRoundRobinRouter {
12+
pub fn new(model_infos: Vec<ModelInfo>) -> Self {
13+
let total_weight = model_infos.iter().map(|m| m.weight).sum();
14+
let length = model_infos.len();
15+
16+
Self {
17+
model_infos: model_infos,
18+
total_weight: total_weight,
19+
current_weights: vec![0; length],
20+
}
1121
}
1222
}
1323

14-
impl Router for WeightedRouter {
24+
impl Router for WeightedRoundRobinRouter {
1525
fn name(&self) -> &'static str {
1626
"WeightedRouter"
1727
}
1828

19-
fn sample(&self, _input: &ResponseRequest) -> ModelId {
20-
// TODO: Implement weighted sampling logic
21-
return self.model_ids[0].clone();
29+
// Use Smooth Weighted Round Robin Algorithm.
30+
fn sample(&mut self, _input: &ResponseRequest) -> ModelId {
31+
// return early if only one model.
32+
if self.model_infos.len() == 1 {
33+
return self.model_infos[0].id.clone();
34+
}
35+
36+
self.current_weights
37+
.iter_mut()
38+
.enumerate()
39+
.for_each(|(i, weight)| {
40+
*weight += self.model_infos[i].weight;
41+
});
42+
43+
let mut max_index = 0;
44+
for i in 1..self.current_weights.len() {
45+
if self.current_weights[i] > self.current_weights[max_index] {
46+
max_index = i;
47+
}
48+
}
49+
50+
self.current_weights[max_index] -= self.total_weight;
51+
self.model_infos[max_index].id.clone()
52+
}
53+
}
54+
55+
#[cfg(test)]
56+
mod tests {
57+
use super::*;
58+
use std::collections::HashMap;
59+
60+
#[test]
61+
fn test_weighted_round_robin_sampling() {
62+
let model_infos = vec![
63+
ModelInfo {
64+
id: "model_x".to_string(),
65+
weight: 1,
66+
},
67+
ModelInfo {
68+
id: "model_y".to_string(),
69+
weight: 3,
70+
},
71+
ModelInfo {
72+
id: "model_z".to_string(),
73+
weight: 6,
74+
},
75+
];
76+
let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone());
77+
let mut counts = HashMap::new();
78+
for _ in 0..1000 {
79+
let sampled_id = wrr.sample(&ResponseRequest::default());
80+
*counts.entry(sampled_id.clone()).or_insert(0) += 1;
81+
}
82+
assert!(counts.len() == model_infos.len());
83+
// Check approximate distribution.
84+
let total_counts: usize = counts.values().sum();
85+
assert!(total_counts == 1000);
86+
let model_x_counts = *counts.get("model_x").unwrap_or(&0);
87+
let model_y_counts = *counts.get("model_y").unwrap_or(&0);
88+
let model_z_counts = *counts.get("model_z").unwrap_or(&0);
89+
let model_x_ratio = model_x_counts as f64 / total_counts as f64;
90+
let model_y_ratio = model_y_counts as f64 / total_counts as f64;
91+
let model_z_ratio = model_z_counts as f64 / total_counts as f64;
92+
assert!((model_x_ratio - 0.1).abs() < 0.05);
93+
assert!((model_y_ratio - 0.3).abs() < 0.05);
94+
assert!((model_z_ratio - 0.6).abs() < 0.05);
2295
}
2396
}

0 commit comments

Comments
 (0)