Skip to content

Commit 3d7ccae

Browse files
committed
polish API
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 7ff0cfb commit 3d7ccae

11 files changed

Lines changed: 220 additions & 69 deletions

File tree

.env.integration-test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
AMRS_API_KEY=your_amrs_api_key_here
22
OPENAI_API_KEY=your_openai_api_key_here
3-
FAKE_API_KEY=your_fake_api_key_here
3+
FAKER_API_KEY=your_faker_api_key_here

Cargo.lock

Lines changed: 59 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ lazy_static = "1.5.0"
1212
rand = "0.9.2"
1313
reqwest = "0.12.26"
1414
serde = "1.0.228"
15-
tokio = "1.48.0"
15+
tokio = { version = "1.48.0", features = ["full"] }

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on
1111
- Flexible routing strategies, including:
1212
- **Random**: Randomly selects a model from the available models.
1313
- **WRR**: Weighted Round Robin selects models based on predefined weights.
14-
- **UCB**: Upper Confidence Bound based model selection (coming soon).
14+
- **UCB1**: Upper Confidence Bound based model selection (coming soon).
1515
- **Adaptive**: Dynamically selects models based on performance metrics (coming soon).
1616

1717
- Broad provider support:
@@ -27,20 +27,21 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
2727
// Before running the code, make sure to set your OpenAI API key in the environment variable:
2828
// export OPENAI_API_KEY="your_openai_api_key"
2929

30-
use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};
30+
use arms::client;
31+
use arms::types::responses;
3132

32-
let config = Config::builder()
33+
let config = client::Config::builder()
3334
.provider("openai")
34-
.routing_mode(RoutingMode::WRR)
35+
.routing_mode(client::RoutingMode::WRR)
3536
.model(
36-
ModelConfig::builder()
37+
client::ModelConfig::builder()
3738
.id("gpt-3.5-turbo")
3839
.weight(2)
3940
.build()
4041
.unwrap(),
4142
)
4243
.model(
43-
ModelConfig::builder()
44+
client::ModelConfig::builder()
4445
.id("gpt-4")
4546
.weight(1)
4647
.build()
@@ -49,8 +50,8 @@ let config = Config::builder()
4950
.build()
5051
.unwrap();
5152

52-
let mut client = Client::new(config);
53-
let request = CreateResponseArgs::default()
53+
let mut client = client::Client::new(config);
54+
let request = responses::CreateResponseArgs::default()
5455
.input("give me a poem about nature")
5556
.build()
5657
.unwrap();

src/client/config.rs

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use derive_builder::Builder;
55
use lazy_static::lazy_static;
66

77
// ------------------ Provider ------------------
8-
const OPENAI_PROVIDER: &str = "OPENAI";
8+
pub const DEFAULT_PROVIDER: &str = "OPENAI";
99

1010
lazy_static! {
1111
pub static ref PROVIDER_BASE_URLS: HashMap<&'static str, &'static str> = {
@@ -14,7 +14,7 @@ lazy_static! {
1414
m.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai");
1515
m.insert("OPENROUTER", "https://openrouter.ai/api/v1");
1616

17-
m.insert("FAKE", "http://localhost:8080"); // test only
17+
m.insert("FAKER", "http://localhost:8080"); // test only
1818
// TODO: support more providers here...
1919
m
2020
};
@@ -82,9 +82,9 @@ impl ModelConfig {
8282
#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")]
8383
pub struct Config {
8484
// global configs for models, will be overridden by model-specific configs
85-
#[builder(default = "https://api.openai.com/v1".to_string())]
86-
pub(crate) base_url: String,
87-
#[builder(default = "OPENAI_PROVIDER.to_string()", setter(custom))]
85+
#[builder(default=None, setter(custom))]
86+
pub(crate) base_url: Option<String>,
87+
#[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
8888
pub(crate) provider: String,
8989
#[builder(default = "0.8")]
9090
pub(crate) temperature: f32,
@@ -104,22 +104,34 @@ impl Config {
104104

105105
// populate will fill in the missing model-specific configs with global configs.
106106
pub fn populate(&mut self) -> &mut Self {
107+
let global_base_url = match self.base_url.is_some() {
108+
true => self.base_url.clone(),
109+
false => Some(
110+
PROVIDER_BASE_URLS
111+
.get(self.provider.as_str())
112+
.unwrap()
113+
.to_string(),
114+
),
115+
};
116+
107117
for model in &mut self.models {
108-
let model_url_exist = model.base_url.is_some();
118+
if model.base_url.is_none() {
119+
if model.provider.is_some() {
120+
model.base_url = Some(
121+
PROVIDER_BASE_URLS
122+
.get(model.provider.as_ref().unwrap().as_str())
123+
.unwrap()
124+
.to_string(),
125+
);
126+
} else {
127+
model.base_url = global_base_url.clone();
128+
}
129+
}
109130

110131
if model.provider.is_none() {
111132
model.provider = Some(self.provider.clone());
112133
}
113134

114-
if !model_url_exist
115-
&& PROVIDER_BASE_URLS.contains_key(model.provider.as_ref().unwrap().as_str())
116-
{
117-
model.base_url =
118-
Some(PROVIDER_BASE_URLS[model.provider.as_ref().unwrap().as_str()].to_string());
119-
}
120-
if !model_url_exist {
121-
model.base_url = Some(self.base_url.clone());
122-
}
123135
if model.temperature.is_none() {
124136
model.temperature = Some(self.temperature);
125137
}
@@ -132,6 +144,10 @@ impl Config {
132144
}
133145

134146
impl ConfigBuilder {
147+
pub fn base_url<S: AsRef<str>>(&mut self, url: S) -> &mut Self {
148+
self.base_url = Some(Some(url.as_ref().to_string()));
149+
self
150+
}
135151
pub fn model(&mut self, model: ModelConfig) -> &mut Self {
136152
let mut models = self.models.clone().unwrap_or_default();
137153
models.push(model);
@@ -180,7 +196,7 @@ impl ConfigBuilder {
180196

181197
// check the existence of API key in environment variables
182198
if let Some(provider) = &model.provider {
183-
let env_var = format!("{}_API_KEY", provider.to_uppercase());
199+
let env_var = format!("{}_API_KEY", provider);
184200
if env::var(&env_var).is_err() {
185201
return Err(format!(
186202
"API key for provider '{}' not found in environment variable '{}'",
@@ -194,16 +210,15 @@ impl ConfigBuilder {
194210
"{}_API_KEY",
195211
self.provider
196212
.as_ref()
197-
.unwrap_or(&OPENAI_PROVIDER.to_string())
213+
.unwrap_or(&DEFAULT_PROVIDER.to_string())
198214
.to_uppercase()
199215
);
200216
if env::var(&env_var).is_err() {
201217
return Err(format!(
202218
"API key for provider '{}' not found in environment variable '{}'",
203219
self.provider
204220
.as_ref()
205-
.unwrap_or(&OPENAI_PROVIDER.to_string())
206-
.to_uppercase(),
221+
.unwrap_or(&DEFAULT_PROVIDER.to_string()),
207222
env_var
208223
));
209224
}
@@ -234,10 +249,8 @@ mod tests {
234249
)
235250
.build();
236251
assert!(valid_simplest_models_cfg.is_ok());
237-
assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == OPENAI_PROVIDER);
238-
assert!(
239-
valid_simplest_models_cfg.as_ref().unwrap().base_url == "https://api.openai.com/v1"
240-
);
252+
assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
253+
assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
241254
assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8);
242255
assert!(
243256
valid_simplest_models_cfg
@@ -285,7 +298,7 @@ mod tests {
285298
// case 4:
286299
// AMRS_API_KEY is set in .env.test already.
287300
let valid_cfg_with_customized_provider = Config::builder()
288-
.base_url("http://example.ai".to_string())
301+
.base_url("http://example.ai")
289302
.max_output_tokens(2048)
290303
.model(
291304
ModelConfig::builder()

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ mod router {
88
}
99

1010
mod provider {
11-
mod fake;
11+
mod faker;
1212
mod openai;
1313
pub mod provider;
1414
}

src/main.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,52 @@
1+
use tokio::runtime::Runtime;
2+
3+
use arms::client;
4+
use arms::types::responses;
5+
16
fn main() {
2-
println!("Hello, world!");
7+
let config = client::Config::builder()
8+
.provider("deepinfra")
9+
.routing_mode(client::RoutingMode::WRR)
10+
.model(
11+
client::ModelConfig::builder()
12+
.name("nvidia/Nemotron-3-Nano-30B-A3B")
13+
.weight(1)
14+
.build()
15+
.unwrap(),
16+
)
17+
.model(
18+
client::ModelConfig::builder()
19+
.name("deepseek-ai/DeepSeek-V3.2")
20+
.weight(2)
21+
.build()
22+
.unwrap(),
23+
)
24+
.build()
25+
.unwrap();
26+
27+
let mut client = client::Client::new(config);
28+
29+
let request = responses::CreateResponseArgs::default()
30+
.input(responses::InputParam::Items(vec![
31+
responses::InputItem::EasyMessage(responses::EasyInputMessage {
32+
r#type: responses::MessageType::Message,
33+
role: responses::Role::User,
34+
content: responses::EasyInputContent::Text("What is AGI?".to_string()),
35+
}),
36+
]))
37+
.build()
38+
.unwrap();
39+
40+
let result = Runtime::new()
41+
.unwrap()
42+
.block_on(client.create_response(request));
43+
44+
match result {
45+
Ok(response) => {
46+
println!("Response ID: {}", response.id);
47+
}
48+
Err(e) => {
49+
eprintln!("Error: {}", e);
50+
}
51+
}
352
}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
use async_trait::async_trait;
22

33
use crate::client::config::{ModelConfig, ModelName};
4-
use crate::provider::provider::{Provider, validate_request};
4+
use crate::provider::provider;
55
use crate::types::error::OpenAIError;
66
use crate::types::responses::{
77
AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
88
OutputTextContent, Response, Status,
99
};
1010

11-
pub struct FakeProvider {
11+
pub struct FakerProvider {
1212
model: ModelName,
1313
}
1414

15-
impl FakeProvider {
15+
impl FakerProvider {
1616
pub fn new(config: ModelConfig) -> Self {
1717
Self {
1818
model: config.name.clone(),
@@ -21,13 +21,13 @@ impl FakeProvider {
2121
}
2222

2323
#[async_trait]
24-
impl Provider for FakeProvider {
24+
impl provider::Provider for FakerProvider {
2525
fn name(&self) -> &'static str {
2626
"FakeProvider"
2727
}
2828

2929
async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
30-
validate_request(&request)?;
30+
provider::validate_responses_request(&request)?;
3131

3232
Ok(Response {
3333
id: "fake-response-id".to_string(),

0 commit comments

Comments
 (0)