Skip to content

Commit fdc6dea

Browse files
authored
Use builder (#75)
* use builder mode in Config Signed-off-by: kerthcet <kerthcet@gmail.com> * rename Signed-off-by: kerthcet <kerthcet@gmail.com> * add tests Signed-off-by: kerthcet <kerthcet@gmail.com> * add tests Signed-off-by: kerthcet <kerthcet@gmail.com> * Use value rather than ref for config Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 1cbab28 commit fdc6dea

11 files changed

Lines changed: 435 additions & 188 deletions

File tree

.env.test

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
AMRS_API_KEY=your_amrs_api_key_here
22
OPENAI_API_KEY=your_openai_api_key_here
3-
FOO_API_KEY=your_foo_api_key_here

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ edition = "2024"
66
[dependencies]
77
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] }
88
async-trait = "0.1.89"
9+
derive_builder = "0.20.2"
910
dotenvy = "0.15.7"
1011
lazy_static = "1.5.0"
1112
rand = "0.9.2"

src/client/client.rs

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,30 @@
11
use std::collections::HashMap;
22

3-
use crate::config::Config;
4-
use crate::config::ModelId;
3+
use crate::config::{Config, ModelConfig, ModelId, RoutingMode};
54
use crate::provider::provider;
65
use crate::router::router;
76

8-
// ------------------ Chat Role ------------------
9-
#[derive(Debug, Clone)]
10-
pub enum ChatRole {
11-
User,
12-
Assistant,
13-
System,
14-
}
15-
16-
// ------------------ Message ------------------
17-
#[derive(Debug, Clone)]
18-
pub struct TextMessage {
19-
pub role: ChatRole,
20-
pub content: String,
21-
}
22-
237
pub struct Client {
24-
config: Config,
258
router_tracker: Option<router::RouterTracker>,
269
router: Box<dyn router::Router>,
2710
providers: HashMap<ModelId, Box<dyn provider::Provider>>,
2811
}
2912

3013
impl Client {
3114
pub fn new(config: Config) -> Self {
32-
let mut cfg = config;
33-
cfg.finalize().expect("Invalid configuration");
15+
let mut cfg = config.clone();
16+
cfg.populate();
3417

3518
let providers = cfg
3619
.models
3720
.iter()
38-
.map(|m| {
39-
let provider = m
40-
.provider
41-
.as_ref()
42-
.expect("Model provider must be specified");
43-
44-
(m.id.clone(), provider::build_provider(provider, m))
45-
})
21+
.map(|m| (m.id.clone(), provider::construct_provider(m)))
4622
.collect();
4723

4824
Self {
49-
config: cfg.clone(),
5025
router_tracker: None,
5126
providers: providers,
52-
router: router::build_router(cfg.routing_mode, cfg.models),
27+
router: router::construct_router(cfg.routing_mode, cfg.models),
5328
}
5429
}
5530

@@ -64,7 +39,110 @@ impl Client {
6439
request: provider::ResponseRequest,
6540
) -> Result<provider::ResponseResult, provider::APIError> {
6641
let model_id = self.router.sample(&request);
67-
let provider = self.providers.get(model_id).unwrap();
42+
let provider = self.providers.get(&model_id).unwrap();
6843
provider.create_response(request).await
6944
}
7045
}
46+
47+
#[cfg(test)]
48+
mod tests {
49+
use super::*;
50+
#[test]
51+
fn test_client_new() {
52+
struct TestCase {
53+
name: &'static str,
54+
config: Config,
55+
expected_router_name: &'static str,
56+
enabled_tracker: bool,
57+
}
58+
59+
let cases = vec![
60+
TestCase {
61+
name: "basic config",
62+
config: Config::builder()
63+
.models(vec![
64+
ModelConfig::builder()
65+
.id("model_c".to_string())
66+
.build()
67+
.unwrap(),
68+
])
69+
.build()
70+
.unwrap(),
71+
expected_router_name: "RandomRouter",
72+
enabled_tracker: false,
73+
},
74+
TestCase {
75+
name: "weighted router",
76+
config: Config::builder()
77+
.routing_mode(RoutingMode::Weighted)
78+
.models(vec![
79+
crate::config::ModelConfig::builder()
80+
.id("model_a".to_string())
81+
.provider(Some("openai".to_string()))
82+
.base_url(Some("https://api.openai.com/v1".to_string()))
83+
.weight(1)
84+
.build()
85+
.unwrap(),
86+
crate::config::ModelConfig::builder()
87+
.id("model_b".to_string())
88+
.provider(Some("openai".to_string()))
89+
.base_url(Some("https://api.openai.com/v1".to_string()))
90+
.weight(3)
91+
.build()
92+
.unwrap(),
93+
])
94+
.build()
95+
.unwrap(),
96+
expected_router_name: "WeightedRouter",
97+
enabled_tracker: false,
98+
},
99+
TestCase {
100+
name: "router tracker enabled",
101+
config: Config::builder()
102+
.models(vec![
103+
ModelConfig::builder()
104+
.id("model_a".to_string())
105+
.provider(Some("openai".to_string()))
106+
.base_url(Some("https://api.openai.com/v1".to_string()))
107+
.build()
108+
.unwrap(),
109+
ModelConfig::builder()
110+
.id("model_b".to_string())
111+
.provider(Some("openai".to_string()))
112+
.base_url(Some("https://api.openai.com/v1".to_string()))
113+
.build()
114+
.unwrap(),
115+
])
116+
.build()
117+
.unwrap(),
118+
expected_router_name: "RandomRouter",
119+
enabled_tracker: true,
120+
},
121+
];
122+
123+
for case in cases {
124+
let mut client = Client::new(case.config.clone());
125+
if case.enabled_tracker {
126+
client.enable_router_tracker();
127+
}
128+
assert_eq!(
129+
client.router.name(),
130+
case.expected_router_name,
131+
"Test case '{}' failed",
132+
case.name
133+
);
134+
assert_eq!(
135+
client.router_tracker.is_some(),
136+
case.enabled_tracker,
137+
"Test case '{}' failed on router tracker state",
138+
case.name
139+
);
140+
assert_eq!(
141+
client.providers.len(),
142+
case.config.models.len(),
143+
"Test case '{}' failed on providers count",
144+
case.name
145+
);
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)