Skip to content

Commit 7ff0cfb

Browse files
committed
use async-openai fields wrapper
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 8d552e0 commit 7ff0cfb

14 files changed

Lines changed: 60 additions & 77 deletions

File tree

src/client/client.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::collections::HashMap;
22

3-
use crate::config::{Config, ModelName};
3+
use crate::client::config::{Config, ModelName};
44
use crate::provider::provider;
55
use crate::router::router;
6+
use crate::types::error::OpenAIError;
7+
use crate::types::responses::{CreateResponse, Response};
68

79
pub struct Client {
810
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
@@ -28,8 +30,8 @@ impl Client {
2830

2931
pub async fn create_response(
3032
&mut self,
31-
request: provider::CreateResponseReq,
32-
) -> Result<provider::CreateResponseRes, provider::APIError> {
33+
request: CreateResponse,
34+
) -> Result<Response, OpenAIError> {
3335
let candidate = self.router.sample(&request);
3436
let provider = self.providers.get(&candidate).unwrap();
3537
provider.create_response(request).await
@@ -39,7 +41,7 @@ impl Client {
3941
#[cfg(test)]
4042
mod tests {
4143
use super::*;
42-
use crate::config::{Config, ModelConfig, RoutingMode};
44+
use crate::client::config::{Config, ModelConfig, RoutingMode};
4345
use dotenvy::from_filename;
4446

4547
#[test]
@@ -71,14 +73,14 @@ mod tests {
7173
config: Config::builder()
7274
.routing_mode(RoutingMode::WRR)
7375
.models(vec![
74-
crate::config::ModelConfig::builder()
76+
crate::client::config::ModelConfig::builder()
7577
.name("model_a".to_string())
7678
.provider(Some("openai".to_string()))
7779
.base_url(Some("https://api.openai.com/v1".to_string()))
7880
.weight(1)
7981
.build()
8082
.unwrap(),
81-
crate::config::ModelConfig::builder()
83+
crate::client::config::ModelConfig::builder()
8284
.name("model_b".to_string())
8385
.provider(Some("openai".to_string()))
8486
.base_url(Some("https://api.openai.com/v1".to_string()))

src/config.rs renamed to src/client/config.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use derive_builder::Builder;
55
use lazy_static::lazy_static;
66

77
// ------------------ Provider ------------------
8-
pub type ProviderName = String;
98
const OPENAI_PROVIDER: &str = "OPENAI";
109

1110
lazy_static! {
@@ -38,7 +37,7 @@ pub struct ModelConfig {
3837
#[builder(default = "None")]
3938
pub(crate) base_url: Option<String>,
4039
#[builder(default = "None", setter(custom))]
41-
pub(crate) provider: Option<ProviderName>,
40+
pub(crate) provider: Option<String>,
4241
#[builder(default = "None")]
4342
pub(crate) temperature: Option<f32>,
4443
#[builder(default = "None")]
@@ -85,8 +84,8 @@ pub struct Config {
8584
// global configs for models, will be overridden by model-specific configs
8685
#[builder(default = "https://api.openai.com/v1".to_string())]
8786
pub(crate) base_url: String,
88-
#[builder(default = "ProviderName::from(OPENAI_PROVIDER)", setter(custom))]
89-
pub(crate) provider: ProviderName,
87+
#[builder(default = "OPENAI_PROVIDER.to_string()", setter(custom))]
88+
pub(crate) provider: String,
9089
#[builder(default = "0.8")]
9190
pub(crate) temperature: f32,
9291
#[builder(default = "1024")]
@@ -195,15 +194,15 @@ impl ConfigBuilder {
195194
"{}_API_KEY",
196195
self.provider
197196
.as_ref()
198-
.unwrap_or(&ProviderName::from(OPENAI_PROVIDER))
197+
.unwrap_or(&OPENAI_PROVIDER.to_string())
199198
.to_uppercase()
200199
);
201200
if env::var(&env_var).is_err() {
202201
return Err(format!(
203202
"API key for provider '{}' not found in environment variable '{}'",
204203
self.provider
205204
.as_ref()
206-
.unwrap_or(&ProviderName::from(OPENAI_PROVIDER))
205+
.unwrap_or(&OPENAI_PROVIDER.to_string())
207206
.to_uppercase(),
208207
env_var
209208
));

src/client/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pub mod client;
2+
pub mod config;
3+
4+
pub use client::Client;
5+
pub use config::{Config, ModelConfig, ModelName, RoutingMode};

src/lib.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
1+
pub mod client;
2+
13
mod router {
24
mod random;
35
pub mod router;
46
pub mod stats;
57
mod wrr;
68
}
7-
mod config;
8-
mod client {
9-
pub mod client;
10-
}
9+
1110
mod provider {
1211
mod fake;
1312
mod openai;
1413
pub mod provider;
1514
}
16-
17-
pub use crate::client::client::Client;
18-
pub use crate::config::{Config, ModelConfig, RoutingMode};
19-
pub use crate::provider::provider::{
20-
APIError, CreateResponseArgs, CreateResponseReq, CreateResponseRes,
21-
};
15+
pub mod types {
16+
pub mod error;
17+
pub mod responses;
18+
}

src/provider/fake.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
1-
use std::str::FromStr;
2-
3-
use async_openai::types::responses::{
4-
AssistantRole, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
5-
OutputTextContent, Status,
6-
};
7-
use async_openai::{Client, config::OpenAIConfig};
81
use async_trait::async_trait;
9-
use reqwest::header::HeaderName;
102

11-
use crate::config::{ModelConfig, ModelName};
12-
use crate::provider::provider::{
13-
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
3+
use crate::client::config::{ModelConfig, ModelName};
4+
use crate::provider::provider::{Provider, validate_request};
5+
use crate::types::error::OpenAIError;
6+
use crate::types::responses::{
7+
AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
8+
OutputTextContent, Response, Status,
149
};
1510

1611
pub struct FakeProvider {
@@ -31,13 +26,10 @@ impl Provider for FakeProvider {
3126
"FakeProvider"
3227
}
3328

34-
async fn create_response(
35-
&self,
36-
request: CreateResponseReq,
37-
) -> Result<CreateResponseRes, APIError> {
29+
async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
3830
validate_request(&request)?;
3931

40-
Ok(CreateResponseRes {
32+
Ok(Response {
4133
id: "fake-response-id".to_string(),
4234
object: "text_completion".to_string(),
4335
model: self.model.clone(),

src/provider/openai.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use async_openai::{Client, config::OpenAIConfig};
22
use async_trait::async_trait;
33
use derive_builder::Builder;
44

5-
use crate::config::{ModelConfig, ModelName};
6-
use crate::provider::provider::{
7-
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
8-
};
5+
use crate::client::config::{ModelConfig, ModelName};
6+
use crate::provider::provider::{Provider, validate_request};
7+
use crate::types::error::OpenAIError;
8+
use crate::types::responses::{CreateResponse, Response};
99

1010
#[derive(Debug, Clone, Builder)]
1111
#[builder(pattern = "mutable", build_fn(skip))]
@@ -51,10 +51,7 @@ impl Provider for OpenAIProvider {
5151
"OpenAIProvider"
5252
}
5353

54-
async fn create_response(
55-
&self,
56-
request: CreateResponseReq,
57-
) -> Result<CreateResponseRes, APIError> {
54+
async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
5855
validate_request(&request)?;
5956
self.client.responses().create(request).await
6057
}

src/provider/provider.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,10 @@
1-
use async_openai::error::OpenAIError as OpenAI_Error;
2-
use async_openai::types::responses::{
3-
CreateResponse, CreateResponseArgs as OpenAICreateResponseArgs, Response,
4-
};
51
use async_trait::async_trait;
62

7-
use crate::config::ModelConfig;
3+
use crate::client::config::ModelConfig;
84
use crate::provider::fake::FakeProvider;
95
use crate::provider::openai::OpenAIProvider;
10-
11-
pub type CreateResponseReq = CreateResponse;
12-
pub type CreateResponseArgs = OpenAICreateResponseArgs;
13-
pub type CreateResponseRes = Response;
14-
pub type APIError = OpenAI_Error;
6+
use crate::types::error::OpenAIError;
7+
use crate::types::responses::{CreateResponse, Response};
158

169
pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> {
1710
let provider = config.provider.as_ref().unwrap();
@@ -25,15 +18,12 @@ pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> {
2518
#[async_trait]
2619
pub trait Provider: Send + Sync {
2720
fn name(&self) -> &'static str;
28-
async fn create_response(
29-
&self,
30-
request: CreateResponseReq,
31-
) -> Result<CreateResponseRes, APIError>;
21+
async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError>;
3222
}
3323

34-
pub fn validate_request(request: &CreateResponseReq) -> Result<(), APIError> {
24+
pub fn validate_request(request: &CreateResponse) -> Result<(), OpenAIError> {
3525
if request.model.is_some() {
36-
return Err(APIError::InvalidArgument(
26+
return Err(OpenAIError::InvalidArgument(
3727
"Model Name must be specified in the config".to_string(),
3828
));
3929
}

src/router/random.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use rand::Rng;
22

3-
use crate::config::ModelName;
4-
use crate::provider::provider::CreateResponseReq;
3+
use crate::client::config::ModelName;
54
use crate::router::router::{ModelInfo, Router};
5+
use crate::types::responses::CreateResponse;
66

77
pub struct RandomRouter {
88
pub model_infos: Vec<ModelInfo>,
@@ -19,7 +19,7 @@ impl Router for RandomRouter {
1919
"RandomRouter"
2020
}
2121

22-
fn sample(&mut self, _input: &CreateResponseReq) -> ModelName {
22+
fn sample(&mut self, _input: &CreateResponse) -> ModelName {
2323
let mut rng = rand::rng();
2424
let idx = rng.random_range(0..self.model_infos.len());
2525
self.model_infos[idx].name.clone()
@@ -50,7 +50,7 @@ mod tests {
5050
let mut counts = std::collections::HashMap::new();
5151

5252
for _ in 0..1000 {
53-
let candidate = router.sample(&CreateResponseReq::default());
53+
let candidate = router.sample(&CreateResponse::default());
5454
*counts.entry(candidate.clone()).or_insert(0) += 1;
5555
}
5656
assert!(counts.len() == model_infos.len());

src/router/router.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
use std::collections::HashMap;
2-
use std::sync::atomic::AtomicUsize;
3-
4-
use crate::config::{ModelConfig, ModelName, RoutingMode};
5-
use crate::provider::provider::CreateResponseReq;
1+
use crate::client::config::{ModelConfig, ModelName, RoutingMode};
62
use crate::router::random::RandomRouter;
73
use crate::router::wrr::WeightedRoundRobinRouter;
4+
use crate::types::responses::CreateResponse;
85

96
#[derive(Debug, Clone)]
107
pub struct ModelInfo {
@@ -28,7 +25,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn
2825

2926
pub trait Router {
3027
fn name(&self) -> &'static str;
31-
fn sample(&mut self, input: &CreateResponseReq) -> ModelName;
28+
fn sample(&mut self, input: &CreateResponse) -> ModelName;
3229
}
3330

3431
#[cfg(test)]

src/router/stats.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::collections::HashMap;
22
use std::sync::atomic::{AtomicUsize, Ordering};
33

4-
use crate::config::ModelName;
4+
use crate::client::config::ModelName;
55

66
pub struct RouterStats {
77
requests_per_model: HashMap<ModelName, AtomicUsize>,

0 commit comments

Comments
 (0)