Skip to content

Commit d6f7739

Browse files
jimyagCopilot
andauthored
feat: support external model and pricing configuration from TOML (#140)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 43eacfe commit d6f7739

5 files changed

Lines changed: 1426 additions & 949 deletions

File tree

src/config.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use crate::models::ModelInfo;
12
use anyhow::{Context, Result};
23
use serde::{Deserialize, Serialize};
34
use std::cell::RefCell;
5+
use std::collections::HashMap;
46
use std::fs;
57
use std::path::PathBuf;
68

@@ -9,6 +11,10 @@ pub struct Config {
911
pub server: ServerConfig,
1012
pub upload: UploadConfig,
1113
pub formatting: FormattingConfig,
14+
#[serde(default)]
15+
pub models: HashMap<String, ModelInfo>,
16+
#[serde(default)]
17+
pub aliases: HashMap<String, String>,
1218
}
1319

1420
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -52,6 +58,8 @@ impl Default for Config {
5258
locale: "en".to_string(),
5359
decimal_places: 2,
5460
},
61+
models: HashMap::new(),
62+
aliases: HashMap::new(),
5563
}
5664
}
5765
}
@@ -171,6 +179,12 @@ pub fn show_config() -> Result<()> {
171179
println!(" Number Human: {}", config.formatting.number_human);
172180
println!(" Locale: {}", config.formatting.locale);
173181
println!(" Decimal Places: {}", config.formatting.decimal_places);
182+
if !config.models.is_empty() {
183+
println!(" Custom Models: {}", config.models.len());
184+
}
185+
if !config.aliases.is_empty() {
186+
println!(" Custom Aliases: {}", config.aliases.len());
187+
}
174188
}
175189
None => {
176190
println!("❌ No configuration file found.");
@@ -226,6 +240,7 @@ pub fn set_config_value(key: &str, value: &str) -> Result<()> {
226240
#[cfg(test)]
227241
mod tests {
228242
use super::*;
243+
use crate::models::PricingStructure;
229244
use tempfile::TempDir;
230245

231246
fn setup_test_config() -> (TempDir, PathBuf) {
@@ -235,6 +250,54 @@ mod tests {
235250
(dir, config_path)
236251
}
237252

253+
#[test]
254+
fn test_config_with_custom_models() {
255+
let toml_str = r#"
256+
[server]
257+
url = "https://custom.example.com"
258+
api_token = "test-token"
259+
260+
[upload]
261+
auto_upload = true
262+
upload_today_only = false
263+
retry_attempts = 5
264+
last_date_uploaded = 0
265+
266+
[formatting]
267+
number_comma = true
268+
number_human = false
269+
locale = "zh"
270+
decimal_places = 4
271+
272+
[models."custom-model"]
273+
pricing = { Flat = { input_per_1m = 10.0, output_per_1m = 20.0 } }
274+
caching = "None"
275+
is_estimated = true
276+
277+
[aliases]
278+
"my-alias" = "custom-model"
279+
"#;
280+
281+
let config: Config = toml::from_str(toml_str).unwrap();
282+
283+
assert_eq!(config.server.url, "https://custom.example.com");
284+
assert!(config.models.contains_key("custom-model"));
285+
286+
let custom_model = config.models.get("custom-model").unwrap();
287+
match &custom_model.pricing {
288+
PricingStructure::Flat {
289+
input_per_1m,
290+
output_per_1m,
291+
} => {
292+
assert_eq!(*input_per_1m, 10.0);
293+
assert_eq!(*output_per_1m, 20.0);
294+
}
295+
_ => panic!("Expected flat pricing"),
296+
}
297+
298+
assert_eq!(config.aliases.get("my-alias").unwrap(), "custom-model");
299+
}
300+
238301
#[test]
239302
fn default_config_round_trip() {
240303
let (_dir, _path) = setup_test_config();

src/main.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ async fn main() {
131131
// Load config file to get defaults
132132
let config = config::Config::load().unwrap_or(None).unwrap_or_default();
133133

134+
// Initialize external models from config
135+
models::init_external_models(config.models.clone(), config.aliases.clone());
136+
134137
// Create format options merging config defaults with CLI overrides
135138
let format_options = utils::NumberFormatOptions {
136139
use_comma: cli.number_comma || config.formatting.number_comma,

0 commit comments

Comments
 (0)