11use std:: collections:: HashMap ;
22
3- use crate :: config:: Config ;
4- use crate :: config:: ModelId ;
3+ use crate :: config:: { Config , ModelConfig , ModelId , RoutingMode } ;
54use crate :: provider:: provider;
65use crate :: router:: router;
76
8- // ------------------ Chat Role ------------------
9- #[ derive( Debug , Clone ) ]
10- pub enum ChatRole {
11- User ,
12- Assistant ,
13- System ,
14- }
15-
16- // ------------------ Message ------------------
17- #[ derive( Debug , Clone ) ]
18- pub struct TextMessage {
19- pub role : ChatRole ,
20- pub content : String ,
21- }
22-
237pub struct Client {
24- config : Config ,
258 router_tracker : Option < router:: RouterTracker > ,
269 router : Box < dyn router:: Router > ,
2710 providers : HashMap < ModelId , Box < dyn provider:: Provider > > ,
2811}
2912
3013impl Client {
3114 pub fn new ( config : Config ) -> Self {
32- let mut cfg = config;
33- cfg. finalize ( ) . expect ( "Invalid configuration" ) ;
15+ let mut cfg = config. clone ( ) ;
16+ cfg. populate ( ) ;
3417
3518 let providers = cfg
3619 . models
3720 . iter ( )
38- . map ( |m| {
39- let provider = m
40- . provider
41- . as_ref ( )
42- . expect ( "Model provider must be specified" ) ;
43-
44- ( m. id . clone ( ) , provider:: build_provider ( provider, m) )
45- } )
21+ . map ( |m| ( m. id . clone ( ) , provider:: construct_provider ( m) ) )
4622 . collect ( ) ;
4723
4824 Self {
49- config : cfg. clone ( ) ,
5025 router_tracker : None ,
5126 providers : providers,
52- router : router:: build_router ( cfg. routing_mode , cfg. models ) ,
27+ router : router:: construct_router ( cfg. routing_mode , cfg. models ) ,
5328 }
5429 }
5530
@@ -64,7 +39,110 @@ impl Client {
6439 request : provider:: ResponseRequest ,
6540 ) -> Result < provider:: ResponseResult , provider:: APIError > {
6641 let model_id = self . router . sample ( & request) ;
67- let provider = self . providers . get ( model_id) . unwrap ( ) ;
42+ let provider = self . providers . get ( & model_id) . unwrap ( ) ;
6843 provider. create_response ( request) . await
6944 }
7045}
46+
47+ #[ cfg( test) ]
48+ mod tests {
49+ use super :: * ;
50+ #[ test]
51+ fn test_client_new ( ) {
52+ struct TestCase {
53+ name : & ' static str ,
54+ config : Config ,
55+ expected_router_name : & ' static str ,
56+ enabled_tracker : bool ,
57+ }
58+
59+ let cases = vec ! [
60+ TestCase {
61+ name: "basic config" ,
62+ config: Config :: builder( )
63+ . models( vec![
64+ ModelConfig :: builder( )
65+ . id( "model_c" . to_string( ) )
66+ . build( )
67+ . unwrap( ) ,
68+ ] )
69+ . build( )
70+ . unwrap( ) ,
71+ expected_router_name: "RandomRouter" ,
72+ enabled_tracker: false ,
73+ } ,
74+ TestCase {
75+ name: "weighted router" ,
76+ config: Config :: builder( )
77+ . routing_mode( RoutingMode :: Weighted )
78+ . models( vec![
79+ crate :: config:: ModelConfig :: builder( )
80+ . id( "model_a" . to_string( ) )
81+ . provider( Some ( "openai" . to_string( ) ) )
82+ . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
83+ . weight( 1 )
84+ . build( )
85+ . unwrap( ) ,
86+ crate :: config:: ModelConfig :: builder( )
87+ . id( "model_b" . to_string( ) )
88+ . provider( Some ( "openai" . to_string( ) ) )
89+ . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
90+ . weight( 3 )
91+ . build( )
92+ . unwrap( ) ,
93+ ] )
94+ . build( )
95+ . unwrap( ) ,
96+ expected_router_name: "WeightedRouter" ,
97+ enabled_tracker: false ,
98+ } ,
99+ TestCase {
100+ name: "router tracker enabled" ,
101+ config: Config :: builder( )
102+ . models( vec![
103+ ModelConfig :: builder( )
104+ . id( "model_a" . to_string( ) )
105+ . provider( Some ( "openai" . to_string( ) ) )
106+ . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
107+ . build( )
108+ . unwrap( ) ,
109+ ModelConfig :: builder( )
110+ . id( "model_b" . to_string( ) )
111+ . provider( Some ( "openai" . to_string( ) ) )
112+ . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
113+ . build( )
114+ . unwrap( ) ,
115+ ] )
116+ . build( )
117+ . unwrap( ) ,
118+ expected_router_name: "RandomRouter" ,
119+ enabled_tracker: true ,
120+ } ,
121+ ] ;
122+
123+ for case in cases {
124+ let mut client = Client :: new ( case. config . clone ( ) ) ;
125+ if case. enabled_tracker {
126+ client. enable_router_tracker ( ) ;
127+ }
128+ assert_eq ! (
129+ client. router. name( ) ,
130+ case. expected_router_name,
131+ "Test case '{}' failed" ,
132+ case. name
133+ ) ;
134+ assert_eq ! (
135+ client. router_tracker. is_some( ) ,
136+ case. enabled_tracker,
137+ "Test case '{}' failed on router tracker state" ,
138+ case. name
139+ ) ;
140+ assert_eq ! (
141+ client. providers. len( ) ,
142+ case. config. models. len( ) ,
143+ "Test case '{}' failed on providers count" ,
144+ case. name
145+ ) ;
146+ }
147+ }
148+ }
0 commit comments