1+ use crate :: models:: ModelInfo ;
12use anyhow:: { Context , Result } ;
23use serde:: { Deserialize , Serialize } ;
34use std:: cell:: RefCell ;
5+ use std:: collections:: HashMap ;
46use std:: fs;
57use std:: path:: PathBuf ;
68
@@ -9,6 +11,10 @@ pub struct Config {
911 pub server : ServerConfig ,
1012 pub upload : UploadConfig ,
1113 pub formatting : FormattingConfig ,
14+ #[ serde( default ) ]
15+ pub models : HashMap < String , ModelInfo > ,
16+ #[ serde( default ) ]
17+ pub aliases : HashMap < String , String > ,
1218}
1319
1420#[ derive( Debug , Serialize , Deserialize , Clone ) ]
@@ -52,6 +58,8 @@ impl Default for Config {
5258 locale : "en" . to_string ( ) ,
5359 decimal_places : 2 ,
5460 } ,
61+ models : HashMap :: new ( ) ,
62+ aliases : HashMap :: new ( ) ,
5563 }
5664 }
5765}
@@ -171,6 +179,12 @@ pub fn show_config() -> Result<()> {
171179 println ! ( " Number Human: {}" , config. formatting. number_human) ;
172180 println ! ( " Locale: {}" , config. formatting. locale) ;
173181 println ! ( " Decimal Places: {}" , config. formatting. decimal_places) ;
182+ if !config. models . is_empty ( ) {
183+ println ! ( " Custom Models: {}" , config. models. len( ) ) ;
184+ }
185+ if !config. aliases . is_empty ( ) {
186+ println ! ( " Custom Aliases: {}" , config. aliases. len( ) ) ;
187+ }
174188 }
175189 None => {
176190 println ! ( "❌ No configuration file found." ) ;
@@ -226,6 +240,7 @@ pub fn set_config_value(key: &str, value: &str) -> Result<()> {
226240#[ cfg( test) ]
227241mod tests {
228242 use super :: * ;
243+ use crate :: models:: PricingStructure ;
229244 use tempfile:: TempDir ;
230245
231246 fn setup_test_config ( ) -> ( TempDir , PathBuf ) {
@@ -235,6 +250,54 @@ mod tests {
235250 ( dir, config_path)
236251 }
237252
253+ #[ test]
254+ fn test_config_with_custom_models ( ) {
255+ let toml_str = r#"
256+ [server]
257+ url = "https://custom.example.com"
258+ api_token = "test-token"
259+
260+ [upload]
261+ auto_upload = true
262+ upload_today_only = false
263+ retry_attempts = 5
264+ last_date_uploaded = 0
265+
266+ [formatting]
267+ number_comma = true
268+ number_human = false
269+ locale = "zh"
270+ decimal_places = 4
271+
272+ [models."custom-model"]
273+ pricing = { Flat = { input_per_1m = 10.0, output_per_1m = 20.0 } }
274+ caching = "None"
275+ is_estimated = true
276+
277+ [aliases]
278+ "my-alias" = "custom-model"
279+ "# ;
280+
281+ let config: Config = toml:: from_str ( toml_str) . unwrap ( ) ;
282+
283+ assert_eq ! ( config. server. url, "https://custom.example.com" ) ;
284+ assert ! ( config. models. contains_key( "custom-model" ) ) ;
285+
286+ let custom_model = config. models . get ( "custom-model" ) . unwrap ( ) ;
287+ match & custom_model. pricing {
288+ PricingStructure :: Flat {
289+ input_per_1m,
290+ output_per_1m,
291+ } => {
292+ assert_eq ! ( * input_per_1m, 10.0 ) ;
293+ assert_eq ! ( * output_per_1m, 20.0 ) ;
294+ }
295+ _ => panic ! ( "Expected flat pricing" ) ,
296+ }
297+
298+ assert_eq ! ( config. aliases. get( "my-alias" ) . unwrap( ) , "custom-model" ) ;
299+ }
300+
238301 #[ test]
239302 fn default_config_round_trip ( ) {
240303 let ( _dir, _path) = setup_test_config ( ) ;
0 commit comments