Skip to content

Commit 8d552e0

Browse files
committed
Rename ModelId to ModelName
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 284a869 commit 8d552e0

13 files changed

Lines changed: 76 additions & 71 deletions

File tree

bindings/python/amrs/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ class BasicModelConfig(BaseModel):
4747
)
4848

4949

50-
type ModelID = str
50+
type ModelName = str
5151

5252
class ModelConfig(BasicModelConfig):
53-
id: ModelID = Field(
53+
id: ModelName = Field(
5454
description="ID of the model to be used."
5555
)
5656
weight: Optional[int] = Field(
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import random
22

3-
from amrs.config import ModelID
3+
from amrs.config import ModelName
44
from amrs.router.router import Router
55

66
class RandomRouter(Router):
7-
def __init__(self, model_list: list[ModelID]):
7+
def __init__(self, model_list: list[ModelName]):
88
super().__init__(model_list)
99

10-
def sample(self, _: str) -> ModelID:
10+
def sample(self, _: str) -> ModelName:
1111
return random.choice(self._model_list)

bindings/python/amrs/router/router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ class ModelInfo:
77
average_latency: float = 0.0
88

99
class Router(abc.ABC):
10-
def __init__(self, model_list: list[config.ModelID]):
10+
def __init__(self, model_list: list[config.ModelName]):
1111
self._model_list = model_list
1212

1313
@abc.abstractmethod
14-
def sample(self, content: str) -> config.ModelID:
14+
def sample(self, content: str) -> config.ModelName:
1515
pass
1616

1717
def new_router(model_cfgs: list[config.ModelConfig], mode: config.RoutingMode) -> Router:

src/client/client.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
use std::collections::HashMap;
22

3-
use crate::config::{Config, ModelId};
3+
use crate::config::{Config, ModelName};
44
use crate::provider::provider;
55
use crate::router::router;
66

77
pub struct Client {
8-
providers: HashMap<ModelId, Box<dyn provider::Provider>>,
8+
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
99
router: Box<dyn router::Router>,
1010
}
1111

@@ -17,7 +17,7 @@ impl Client {
1717
let providers = cfg
1818
.models
1919
.iter()
20-
.map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
20+
.map(|m| (m.name.clone(), provider::construct_provider(m.clone())))
2121
.collect();
2222

2323
Self {
@@ -30,8 +30,8 @@ impl Client {
3030
&mut self,
3131
request: provider::CreateResponseReq,
3232
) -> Result<provider::CreateResponseRes, provider::APIError> {
33-
let model_id = self.router.sample(&request);
34-
let provider = self.providers.get(&model_id).unwrap();
33+
let candidate = self.router.sample(&request);
34+
let provider = self.providers.get(&candidate).unwrap();
3535
provider.create_response(request).await
3636
}
3737
}
@@ -58,7 +58,7 @@ mod tests {
5858
config: Config::builder()
5959
.models(vec![
6060
ModelConfig::builder()
61-
.id("model_c".to_string())
61+
.name("model_c".to_string())
6262
.build()
6363
.unwrap(),
6464
])
@@ -72,14 +72,14 @@ mod tests {
7272
.routing_mode(RoutingMode::WRR)
7373
.models(vec![
7474
crate::config::ModelConfig::builder()
75-
.id("model_a".to_string())
75+
.name("model_a".to_string())
7676
.provider(Some("openai".to_string()))
7777
.base_url(Some("https://api.openai.com/v1".to_string()))
7878
.weight(1)
7979
.build()
8080
.unwrap(),
8181
crate::config::ModelConfig::builder()
82-
.id("model_b".to_string())
82+
.name("model_b".to_string())
8383
.provider(Some("openai".to_string()))
8484
.base_url(Some("https://api.openai.com/v1".to_string()))
8585
.weight(3)
@@ -95,13 +95,13 @@ mod tests {
9595
config: Config::builder()
9696
.models(vec![
9797
ModelConfig::builder()
98-
.id("model_a".to_string())
98+
.name("model_a".to_string())
9999
.provider(Some("openai".to_string()))
100100
.base_url(Some("https://api.openai.com/v1".to_string()))
101101
.build()
102102
.unwrap(),
103103
ModelConfig::builder()
104-
.id("model_b".to_string())
104+
.name("model_b".to_string())
105105
.provider(Some("openai".to_string()))
106106
.base_url(Some("https://api.openai.com/v1".to_string()))
107107
.build()

src/config.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ pub enum RoutingMode {
2929
}
3030

3131
// ------------------ Model Config ------------------
32-
pub type ModelId = String;
32+
pub type ModelName = String;
3333

3434
#[derive(Debug, Clone, Builder)]
3535
#[builder(build_fn(validate = "Self::validate"), pattern = "mutable")]
@@ -45,14 +45,14 @@ pub struct ModelConfig {
4545
pub(crate) max_output_tokens: Option<usize>,
4646

4747
#[builder(setter(custom))]
48-
pub(crate) id: ModelId,
48+
pub(crate) name: ModelName,
4949
#[builder(default=-1)]
5050
pub(crate) weight: i32,
5151
}
5252

5353
impl ModelConfigBuilder {
54-
pub fn id<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
55-
self.id = Some(name.as_ref().to_string());
54+
pub fn name<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
55+
self.name = Some(name.as_ref().to_string());
5656
self
5757
}
5858

@@ -65,8 +65,8 @@ impl ModelConfigBuilder {
6565
}
6666

6767
fn validate(&self) -> Result<(), String> {
68-
if self.id.is_none() {
69-
return Err("Model id must be provided.".to_string());
68+
if self.name.is_none() {
69+
return Err("Model name must be provided.".to_string());
7070
}
7171
Ok(())
7272
}
@@ -157,15 +157,15 @@ impl ConfigBuilder {
157157
{
158158
return Err(format!(
159159
"Model '{}' weight must be non-negative in Weighted routing mode.",
160-
model.id
160+
model.name
161161
));
162162
}
163163

164164
if let Some(max_output_tokens) = model.max_output_tokens {
165165
if max_output_tokens <= 0 {
166166
return Err(format!(
167167
"Model '{}' max_output_tokens must be positive.",
168-
model.id
168+
model.name
169169
));
170170
}
171171
}
@@ -174,7 +174,7 @@ impl ConfigBuilder {
174174
if temperature < 0.0 || temperature > 1.0 {
175175
return Err(format!(
176176
"Model '{}' temperature must be between 0.0 and 1.0.",
177-
model.id
177+
model.name
178178
));
179179
}
180180
}
@@ -229,7 +229,7 @@ mod tests {
229229
let valid_simplest_models_cfg = Config::builder()
230230
.model(
231231
ModelConfig::builder()
232-
.id("gpt-4".to_string())
232+
.name("gpt-4".to_string())
233233
.build()
234234
.unwrap(),
235235
)
@@ -259,11 +259,11 @@ mod tests {
259259
let valid_cfg = Config::builder()
260260
.models(vec![
261261
ModelConfig::builder()
262-
.id("gpt-3.5-turbo".to_string())
262+
.name("gpt-3.5-turbo".to_string())
263263
.build()
264264
.unwrap(),
265265
ModelConfig::builder()
266-
.id("gpt-4".to_string())
266+
.name("gpt-4".to_string())
267267
.build()
268268
.unwrap(),
269269
])
@@ -275,7 +275,7 @@ mod tests {
275275
let invalid_cfg_with_no_api_key = Config::builder()
276276
.model(
277277
ModelConfig::builder()
278-
.id("some-model".to_string())
278+
.name("some-model".to_string())
279279
.build()
280280
.unwrap(),
281281
)
@@ -290,7 +290,7 @@ mod tests {
290290
.max_output_tokens(2048)
291291
.model(
292292
ModelConfig::builder()
293-
.id("custom-model")
293+
.name("custom-model")
294294
.provider(Some("AMRS"))
295295
.build()
296296
.unwrap(),
@@ -303,7 +303,7 @@ mod tests {
303303
assert!(invalid_empty_models_cfg.is_err());
304304

305305
// case 6:
306-
print!("validating invalid empty model id config");
306+
print!("validating invalid empty model name config");
307307
let invalid_empty_model_id_cfg = ModelConfig::builder().build();
308308
assert!(invalid_empty_model_id_cfg.is_err());
309309
}
@@ -317,7 +317,7 @@ mod tests {
317317
.max_output_tokens(1500)
318318
.model(
319319
ModelConfig::builder()
320-
.id("model-1".to_string())
320+
.name("model-1".to_string())
321321
.build()
322322
.unwrap(),
323323
)
@@ -338,7 +338,7 @@ mod tests {
338338
let mut valid_specified_cfg = Config::builder()
339339
.provider("AMRS".to_string())
340340
.base_url("http://custom-api.ai".to_string())
341-
.model(ModelConfig::builder().id("model-2").build().unwrap())
341+
.model(ModelConfig::builder().name("model-2").build().unwrap())
342342
.build();
343343
valid_specified_cfg.as_mut().unwrap().populate();
344344

src/provider/fake.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ use async_openai::{Client, config::OpenAIConfig};
88
use async_trait::async_trait;
99
use reqwest::header::HeaderName;
1010

11-
use crate::config::{ModelConfig, ModelId};
11+
use crate::config::{ModelConfig, ModelName};
1212
use crate::provider::provider::{
1313
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
1414
};
1515

1616
pub struct FakeProvider {
17-
model: ModelId,
17+
model: ModelName,
1818
}
1919

2020
impl FakeProvider {
2121
pub fn new(config: ModelConfig) -> Self {
2222
Self {
23-
model: config.id.clone(),
23+
model: config.name.clone(),
2424
}
2525
}
2626
}

src/provider/openai.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ use async_openai::{Client, config::OpenAIConfig};
22
use async_trait::async_trait;
33
use derive_builder::Builder;
44

5-
use crate::config::{ModelConfig, ModelId};
5+
use crate::config::{ModelConfig, ModelName};
66
use crate::provider::provider::{
77
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
88
};
99

1010
#[derive(Debug, Clone, Builder)]
1111
#[builder(pattern = "mutable", build_fn(skip))]
1212
pub struct OpenAIProvider {
13-
model: ModelId,
13+
model: ModelName,
1414
config: OpenAIConfig,
1515
client: Client<OpenAIConfig>,
1616
}
@@ -28,7 +28,7 @@ impl OpenAIProvider {
2828
.with_api_key(api_key);
2929

3030
OpenAIProviderBuilder {
31-
model: Some(config.id.clone()),
31+
model: Some(config.name.clone()),
3232
config: Some(openai_config),
3333
client: None,
3434
}

src/provider/provider.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ pub trait Provider: Send + Sync {
3434
pub fn validate_request(request: &CreateResponseReq) -> Result<(), APIError> {
3535
if request.model.is_some() {
3636
return Err(APIError::InvalidArgument(
37-
"Model ID must be specified in the config".to_string(),
37+
"Model Name must be specified in the config".to_string(),
3838
));
3939
}
4040
Ok(())
@@ -56,7 +56,7 @@ mod tests {
5656
TestCase {
5757
name: "OpenAI Provider",
5858
config: ModelConfig::builder()
59-
.id("test-model".to_string())
59+
.name("test-model".to_string())
6060
.provider(Some("openai".to_string()))
6161
.base_url(Some("https://api.openai.com/v1".to_string()))
6262
.build()
@@ -66,7 +66,7 @@ mod tests {
6666
TestCase {
6767
name: "Unsupported Provider",
6868
config: ModelConfig::builder()
69-
.id("test-model".to_string())
69+
.name("test-model".to_string())
7070
.provider(Some("unsupported".to_string()))
7171
.base_url(Some("https://api.openai.com/v1".to_string()))
7272
.build()

src/router/random.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use rand::Rng;
22

3-
use crate::config::ModelId;
3+
use crate::config::ModelName;
44
use crate::provider::provider::CreateResponseReq;
55
use crate::router::router::{ModelInfo, Router};
66

@@ -19,10 +19,10 @@ impl Router for RandomRouter {
1919
"RandomRouter"
2020
}
2121

22-
fn sample(&mut self, _input: &CreateResponseReq) -> ModelId {
22+
fn sample(&mut self, _input: &CreateResponseReq) -> ModelName {
2323
let mut rng = rand::rng();
2424
let idx = rng.random_range(0..self.model_infos.len());
25-
self.model_infos[idx].id.clone()
25+
self.model_infos[idx].name.clone()
2626
}
2727
}
2828

@@ -34,24 +34,24 @@ mod tests {
3434
fn test_random_router_sampling() {
3535
let model_infos = vec![
3636
ModelInfo {
37-
id: "model_x".to_string(),
37+
name: "model_x".to_string(),
3838
weight: 1,
3939
},
4040
ModelInfo {
41-
id: "model_y".to_string(),
41+
name: "model_y".to_string(),
4242
weight: 2,
4343
},
4444
ModelInfo {
45-
id: "model_z".to_string(),
45+
name: "model_z".to_string(),
4646
weight: 3,
4747
},
4848
];
4949
let mut router = RandomRouter::new(model_infos.clone());
5050
let mut counts = std::collections::HashMap::new();
5151

5252
for _ in 0..1000 {
53-
let sampled_id = router.sample(&CreateResponseReq::default());
54-
*counts.entry(sampled_id.clone()).or_insert(0) += 1;
53+
let candidate = router.sample(&CreateResponseReq::default());
54+
*counts.entry(candidate.clone()).or_insert(0) += 1;
5555
}
5656
assert!(counts.len() == model_infos.len());
5757
for count in counts.values() {

0 commit comments

Comments
 (0)