Skip to content

Commit f257c4a

Browse files
committed
Replace completion with chat completion
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent aa09c1a commit f257c4a

14 files changed

Lines changed: 57 additions & 38 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version = "0.1.0"
44
edition = "2024"
55

66
[dependencies]
7-
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "completions"] }
7+
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "chat-completion"] }
88
async-trait = "0.1.89"
99
derive_builder = "0.20.2"
1010
dotenvy = "0.15.7"

README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
2727
// Before running the code, make sure to set your OpenAI API key in the environment variable:
2828
// export OPENAI_API_KEY="your_openai_api_key"
2929

30+
use tokio::runtime::Runtime;
3031
use arms::client;
31-
use arms::types::responses;
32+
use arms::types::chat;
3233

3334
let config = client::Config::builder()
3435
.provider("openai")
@@ -51,12 +52,15 @@ let config = client::Config::builder()
5152
.unwrap();
5253

5354
let mut client = client::Client::new(config);
54-
let request = responses::CreateResponseArgs::default()
55-
.input("give me a poem about nature")
55+
let request = chat::CreateChatCompletionRequestArgs::default()
56+
.messages([
57+
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
58+
chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?").into(),
59+
])
5660
.build()
5761
.unwrap();
5862

59-
let response = client.create_response(request).await.unwrap();
63+
let result = Runtime::new().unwrap().block_on(client.create_completion(request));
6064
```
6165

6266
## Contributing

src/client/client.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::client::config::{Config, ModelName};
44
use crate::provider::provider;
55
use crate::router::router;
66
use crate::types::error::OpenAIError;
7-
use crate::types::{completions, responses};
7+
use crate::types::{chat, responses};
88

99
pub struct Client {
1010
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
@@ -37,10 +37,11 @@ impl Client {
3737
provider.create_response(request).await
3838
}
3939

40+
// This is chat completion endpoint.
4041
pub async fn create_completion(
4142
&mut self,
42-
request: completions::CreateCompletionRequest,
43-
) -> Result<completions::CreateCompletionResponse, OpenAIError> {
43+
request: chat::CreateChatCompletionRequest,
44+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
4445
let candidate = self.router.sample();
4546
let provider = self.providers.get(&candidate).unwrap();
4647
provider.create_completion(request).await

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod provider {
1414
pub mod provider;
1515
}
1616
pub mod types {
17-
pub mod completions;
17+
pub mod chat;
1818
pub mod error;
1919
pub mod responses;
2020
}

src/main.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use tokio::runtime::Runtime;
22

33
use arms::client;
4-
use arms::types::{completions, responses};
4+
use arms::types::chat;
55

66
fn main() {
77
// case 1: completion with DeepInfra provider.
@@ -27,8 +27,17 @@ fn main() {
2727

2828
let mut client = client::Client::new(config);
2929

30-
let request = completions::CreateCompletionRequestArgs::default()
31-
.prompt("How to achieve AGI?")
30+
let request = chat::CreateChatCompletionRequestArgs::default()
31+
.messages([
32+
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
33+
chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?")
34+
.into(),
35+
// chat::ChatCompletionRequestAssistantMessage::from(
36+
// "The Los Angeles Dodgers won the World Series in 2020.",
37+
// )
38+
// .into(),
39+
// chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(),
40+
])
3241
.build()
3342
.unwrap();
3443

@@ -38,7 +47,7 @@ fn main() {
3847

3948
match result {
4049
Ok(response) => {
41-
println!("Response ID: {}", response.id);
50+
println!("Response: {:?}", response);
4251
}
4352
Err(e) => {
4453
eprintln!("Error: {}", e);

src/provider/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use crate::types::error::OpenAIError;
2-
use crate::types::{completions, responses};
2+
use crate::types::{chat, responses};
33

44
pub fn validate_completion_request(
5-
request: &completions::CreateCompletionRequest,
5+
request: &chat::CreateChatCompletionRequest,
66
) -> Result<(), OpenAIError> {
77
if request.model != "" {
88
return Err(OpenAIError::InvalidArgument(

src/provider/faker.rs

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
use async_openai::types::chat::Choice;
21
use async_trait::async_trait;
32

43
use crate::client::config::{ModelConfig, ModelName};
54
use crate::provider::{common, provider};
6-
use crate::types::completions::{CreateCompletionRequest, CreateCompletionResponse};
5+
use crate::types::chat;
76
use crate::types::error::OpenAIError;
87
use crate::types::responses::{
98
AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
@@ -76,22 +75,30 @@ impl provider::Provider for FakerProvider {
7675

7776
async fn create_completion(
7877
&self,
79-
_request: CreateCompletionRequest,
80-
) -> Result<CreateCompletionResponse, OpenAIError> {
81-
common::validate_completion_request(&_request)?;
82-
83-
Ok(CreateCompletionResponse {
78+
request: chat::CreateChatCompletionRequest,
79+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
80+
common::validate_completion_request(&request)?;
81+
Ok(chat::CreateChatCompletionResponse {
8482
id: "fake-completion-id".to_string(),
8583
object: "text_completion".to_string(),
8684
created: 1_600_000_000,
8785
model: self.model.clone(),
88-
choices: vec![Choice {
86+
usage: None,
87+
service_tier: None,
88+
choices: vec![chat::ChatChoice {
8989
index: 0,
90-
text: "This is a fake completion.".to_string(),
91-
logprobs: None,
90+
message: chat::ChatCompletionResponseMessage {
91+
role: chat::Role::Assistant,
92+
content: Some("This is a fake chat completion.".to_string()),
93+
refusal: None,
94+
tool_calls: None,
95+
annotations: None,
96+
function_call: None,
97+
audio: None,
98+
},
9299
finish_reason: None,
100+
logprobs: None,
93101
}],
94-
usage: None,
95102
system_fingerprint: None,
96103
})
97104
}

src/provider/openai.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use derive_builder::Builder;
55
use crate::client::config::{DEFAULT_PROVIDER, ModelConfig, ModelName};
66
use crate::provider::{common, provider};
77
use crate::types::error::OpenAIError;
8-
use crate::types::{completions, responses};
8+
use crate::types::{chat, responses};
99

1010
#[derive(Debug, Clone, Builder)]
1111
#[builder(pattern = "mutable", build_fn(skip))]
@@ -65,14 +65,14 @@ impl provider::Provider for OpenAIProvider {
6565

6666
async fn create_completion(
6767
&self,
68-
request: completions::CreateCompletionRequest,
69-
) -> Result<completions::CreateCompletionResponse, OpenAIError> {
68+
request: chat::CreateChatCompletionRequest,
69+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
7070
common::validate_completion_request(&request)?;
7171

7272
// Set the model after validation since model is bind to the provider.
7373
let mut req = request.clone();
7474
req.model = self.model.clone();
75-
self.client.completions().create(req).await
75+
self.client.chat().create(req).await
7676
}
7777

7878
async fn create_response(

src/provider/provider.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::client::config::ModelConfig;
44
use crate::provider::faker::FakerProvider;
55
use crate::provider::openai::OpenAIProvider;
66
use crate::types::error::OpenAIError;
7-
use crate::types::{completions, responses};
7+
use crate::types::{chat, responses};
88

99
// Not all providers support response endpoint.
1010
pub const RESPONSE_ENDPOINT_PROVIDERS: &[&str] = &["FAKER", "OPENAI"];
@@ -33,8 +33,8 @@ pub trait Provider: Send + Sync {
3333
) -> Result<responses::Response, OpenAIError>;
3434
async fn create_completion(
3535
&self,
36-
request: completions::CreateCompletionRequest,
37-
) -> Result<completions::CreateCompletionResponse, OpenAIError>;
36+
request: chat::CreateChatCompletionRequest,
37+
) -> Result<chat::CreateChatCompletionResponse, OpenAIError>;
3838
}
3939

4040
#[cfg(test)]

src/router/router.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use crate::client::config::{ModelConfig, ModelName, RoutingMode};
22
use crate::router::random::RandomRouter;
33
use crate::router::wrr::WeightedRoundRobinRouter;
4-
use crate::types::responses::CreateResponse;
54

65
#[derive(Debug, Clone)]
76
pub struct ModelInfo {

0 commit comments

Comments
 (0)