11use std:: collections:: HashMap ;
22
3- use crate :: config:: { Config , ModelId } ;
3+ use crate :: client :: config:: { Config , ModelName } ;
44use crate :: provider:: provider;
55use crate :: router:: router;
6+ use crate :: types:: error:: OpenAIError ;
7+ use crate :: types:: responses:: { CreateResponse , Response } ;
68
79pub struct Client {
8- providers : HashMap < ModelId , Box < dyn provider:: Provider > > ,
10+ providers : HashMap < ModelName , Box < dyn provider:: Provider > > ,
911 router : Box < dyn router:: Router > ,
1012}
1113
@@ -17,7 +19,7 @@ impl Client {
1719 let providers = cfg
1820 . models
1921 . iter ( )
20- . map ( |m| ( m. id . clone ( ) , provider:: construct_provider ( m. clone ( ) ) ) )
22+ . map ( |m| ( m. name . clone ( ) , provider:: construct_provider ( m. clone ( ) ) ) )
2123 . collect ( ) ;
2224
2325 Self {
@@ -28,18 +30,18 @@ impl Client {
2830
2931 pub async fn create_response (
3032 & mut self ,
31- request : provider :: CreateResponseReq ,
32- ) -> Result < provider :: CreateResponseRes , provider :: APIError > {
33- let model_id = self . router . sample ( & request) ;
34- let provider = self . providers . get ( & model_id ) . unwrap ( ) ;
33+ request : CreateResponse ,
34+ ) -> Result < Response , OpenAIError > {
35+ let candidate = self . router . sample ( & request) ;
36+ let provider = self . providers . get ( & candidate ) . unwrap ( ) ;
3537 provider. create_response ( request) . await
3638 }
3739}
3840
3941#[ cfg( test) ]
4042mod tests {
4143 use super :: * ;
42- use crate :: config:: { Config , ModelConfig , RoutingMode } ;
44+ use crate :: client :: config:: { Config , ModelConfig , RoutingMode } ;
4345 use dotenvy:: from_filename;
4446
4547 #[ test]
@@ -58,7 +60,7 @@ mod tests {
5860 config: Config :: builder( )
5961 . models( vec![
6062 ModelConfig :: builder( )
61- . id ( "model_c" . to_string( ) )
63+ . name ( "model_c" . to_string( ) )
6264 . build( )
6365 . unwrap( ) ,
6466 ] )
@@ -71,15 +73,15 @@ mod tests {
7173 config: Config :: builder( )
7274 . routing_mode( RoutingMode :: WRR )
7375 . models( vec![
74- crate :: config:: ModelConfig :: builder( )
75- . id ( "model_a" . to_string( ) )
76+ crate :: client :: config:: ModelConfig :: builder( )
77+ . name ( "model_a" . to_string( ) )
7678 . provider( Some ( "openai" . to_string( ) ) )
7779 . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
7880 . weight( 1 )
7981 . build( )
8082 . unwrap( ) ,
81- crate :: config:: ModelConfig :: builder( )
82- . id ( "model_b" . to_string( ) )
83+ crate :: client :: config:: ModelConfig :: builder( )
84+ . name ( "model_b" . to_string( ) )
8385 . provider( Some ( "openai" . to_string( ) ) )
8486 . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
8587 . weight( 3 )
@@ -95,13 +97,13 @@ mod tests {
9597 config: Config :: builder( )
9698 . models( vec![
9799 ModelConfig :: builder( )
98- . id ( "model_a" . to_string( ) )
100+ . name ( "model_a" . to_string( ) )
99101 . provider( Some ( "openai" . to_string( ) ) )
100102 . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
101103 . build( )
102104 . unwrap( ) ,
103105 ModelConfig :: builder( )
104- . id ( "model_b" . to_string( ) )
106+ . name ( "model_b" . to_string( ) )
105107 . provider( Some ( "openai" . to_string( ) ) )
106108 . base_url( Some ( "https://api.openai.com/v1" . to_string( ) ) )
107109 . build( )
0 commit comments