@@ -5,7 +5,7 @@ use derive_builder::Builder;
55use lazy_static:: lazy_static;
66
77// ------------------ Provider ------------------
8- const OPENAI_PROVIDER : & str = "OPENAI" ;
8+ pub const DEFAULT_PROVIDER : & str = "OPENAI" ;
99
1010lazy_static ! {
1111 pub static ref PROVIDER_BASE_URLS : HashMap <& ' static str , & ' static str > = {
@@ -14,7 +14,7 @@ lazy_static! {
1414 m. insert( "DEEPINFRA" , "https://api.deepinfra.com/v1/openai" ) ;
1515 m. insert( "OPENROUTER" , "https://openrouter.ai/api/v1" ) ;
1616
17- m. insert( "FAKE " , "http://localhost:8080" ) ; // test only
17+ m. insert( "FAKER " , "http://localhost:8080" ) ; // test only
1818 // TODO: support more providers here...
1919 m
2020 } ;
@@ -82,9 +82,9 @@ impl ModelConfig {
8282#[ builder( build_fn( validate = "Self::validate" ) , pattern = "mutable" ) ]
8383pub struct Config {
8484 // global configs for models, will be overridden by model-specific configs
85- #[ builder( default = "https://api.openai.com/v1" . to_string ( ) ) ]
86- pub ( crate ) base_url : String ,
87- #[ builder( default = "OPENAI_PROVIDER .to_string()" , setter( custom) ) ]
85+ #[ builder( default = None , setter ( custom ) ) ]
86+ pub ( crate ) base_url : Option < String > ,
87+ #[ builder( default = "DEFAULT_PROVIDER .to_string()" , setter( custom) ) ]
8888 pub ( crate ) provider : String ,
8989 #[ builder( default = "0.8" ) ]
9090 pub ( crate ) temperature : f32 ,
@@ -104,22 +104,34 @@ impl Config {
104104
105105 // populate will fill in the missing model-specific configs with global configs.
106106 pub fn populate ( & mut self ) -> & mut Self {
107+ let global_base_url = match self . base_url . is_some ( ) {
108+ true => self . base_url . clone ( ) ,
109+ false => Some (
110+ PROVIDER_BASE_URLS
111+ . get ( self . provider . as_str ( ) )
112+ . unwrap ( )
113+ . to_string ( ) ,
114+ ) ,
115+ } ;
116+
107117 for model in & mut self . models {
108- let model_url_exist = model. base_url . is_some ( ) ;
118+ if model. base_url . is_none ( ) {
119+ if model. provider . is_some ( ) {
120+ model. base_url = Some (
121+ PROVIDER_BASE_URLS
122+ . get ( model. provider . as_ref ( ) . unwrap ( ) . as_str ( ) )
123+ . unwrap ( )
124+ . to_string ( ) ,
125+ ) ;
126+ } else {
127+ model. base_url = global_base_url. clone ( ) ;
128+ }
129+ }
109130
110131 if model. provider . is_none ( ) {
111132 model. provider = Some ( self . provider . clone ( ) ) ;
112133 }
113134
114- if !model_url_exist
115- && PROVIDER_BASE_URLS . contains_key ( model. provider . as_ref ( ) . unwrap ( ) . as_str ( ) )
116- {
117- model. base_url =
118- Some ( PROVIDER_BASE_URLS [ model. provider . as_ref ( ) . unwrap ( ) . as_str ( ) ] . to_string ( ) ) ;
119- }
120- if !model_url_exist {
121- model. base_url = Some ( self . base_url . clone ( ) ) ;
122- }
123135 if model. temperature . is_none ( ) {
124136 model. temperature = Some ( self . temperature ) ;
125137 }
@@ -132,6 +144,10 @@ impl Config {
132144}
133145
134146impl ConfigBuilder {
147+ pub fn base_url < S : AsRef < str > > ( & mut self , url : S ) -> & mut Self {
148+ self . base_url = Some ( Some ( url. as_ref ( ) . to_string ( ) ) ) ;
149+ self
150+ }
135151 pub fn model ( & mut self , model : ModelConfig ) -> & mut Self {
136152 let mut models = self . models . clone ( ) . unwrap_or_default ( ) ;
137153 models. push ( model) ;
@@ -180,7 +196,7 @@ impl ConfigBuilder {
180196
181197 // check the existence of API key in environment variables
182198 if let Some ( provider) = & model. provider {
183- let env_var = format ! ( "{}_API_KEY" , provider. to_uppercase ( ) ) ;
199+ let env_var = format ! ( "{}_API_KEY" , provider) ;
184200 if env:: var ( & env_var) . is_err ( ) {
185201 return Err ( format ! (
186202 "API key for provider '{}' not found in environment variable '{}'" ,
@@ -194,16 +210,15 @@ impl ConfigBuilder {
194210 "{}_API_KEY" ,
195211 self . provider
196212 . as_ref( )
197- . unwrap_or( & OPENAI_PROVIDER . to_string( ) )
213+ . unwrap_or( & DEFAULT_PROVIDER . to_string( ) )
198214 . to_uppercase( )
199215 ) ;
200216 if env:: var ( & env_var) . is_err ( ) {
201217 return Err ( format ! (
202218 "API key for provider '{}' not found in environment variable '{}'" ,
203219 self . provider
204220 . as_ref( )
205- . unwrap_or( & OPENAI_PROVIDER . to_string( ) )
206- . to_uppercase( ) ,
221+ . unwrap_or( & DEFAULT_PROVIDER . to_string( ) ) ,
207222 env_var
208223 ) ) ;
209224 }
@@ -234,10 +249,8 @@ mod tests {
234249 )
235250 . build ( ) ;
236251 assert ! ( valid_simplest_models_cfg. is_ok( ) ) ;
237- assert ! ( valid_simplest_models_cfg. as_ref( ) . unwrap( ) . provider == OPENAI_PROVIDER ) ;
238- assert ! (
239- valid_simplest_models_cfg. as_ref( ) . unwrap( ) . base_url == "https://api.openai.com/v1"
240- ) ;
252+ assert ! ( valid_simplest_models_cfg. as_ref( ) . unwrap( ) . provider == DEFAULT_PROVIDER ) ;
253+ assert ! ( valid_simplest_models_cfg. as_ref( ) . unwrap( ) . base_url == None ) ;
241254 assert ! ( valid_simplest_models_cfg. as_ref( ) . unwrap( ) . temperature == 0.8 ) ;
242255 assert ! (
243256 valid_simplest_models_cfg
@@ -285,7 +298,7 @@ mod tests {
285298 // case 4:
286299 // AMRS_API_KEY is set in .env.test already.
287300 let valid_cfg_with_customized_provider = Config :: builder ( )
288- . base_url ( "http://example.ai" . to_string ( ) )
301+ . base_url ( "http://example.ai" )
289302 . max_output_tokens ( 2048 )
290303 . model (
291304 ModelConfig :: builder ( )
0 commit comments