Skip to content

Commit 7d9f5c3

Browse files
committed
openai: Support toolChoice
1 parent 8ac453c commit 7d9f5c3

1 file changed

Lines changed: 22 additions & 7 deletions

File tree

  • services/openai-dialog/src

services/openai-dialog/src/lib.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ use openai_api_rs::realtime::{
1515
api::RealtimeClient,
1616
client_event::{self, ClientEvent},
1717
server_event::{self, ServerEvent},
18-
types::{self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus},
18+
types::{
19+
self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus,
20+
ToolChoice,
21+
},
1922
};
2023
use serde::{Deserialize, Serialize};
2124
use tokio::{net::TcpStream, select};
@@ -42,6 +45,7 @@ pub struct Params {
4245
pub temperature: Option<f32>,
4346
#[serde(default)]
4447
pub tools: Vec<types::ToolDefinition>,
48+
tool_choice: Option<ToolChoice>,
4549
}
4650

4751
impl Params {
@@ -54,6 +58,7 @@ impl Params {
5458
voice: None,
5559
temperature: None,
5660
tools: vec![],
61+
tool_choice: None,
5762
}
5863
}
5964
}
@@ -113,6 +118,7 @@ pub enum ServiceInputEvent {
113118
Prompt {
114119
text: String,
115120
},
121+
#[serde(rename_all = "camelCase")]
116122
SessionUpdate {
117123
#[serde(skip_serializing_if = "Option::is_none")]
118124
instructions: Option<String>,
@@ -122,6 +128,8 @@ pub enum ServiceInputEvent {
122128
temperature: Option<f32>,
123129
#[serde(skip_serializing_if = "Option::is_none")]
124130
tools: Option<Vec<types::ToolDefinition>>,
131+
#[serde(skip_serializing_if = "Option::is_none")]
132+
tool_choice: Option<ToolChoice>,
125133
},
126134
}
127135

@@ -257,11 +265,6 @@ impl Client {
257265
send_update = true;
258266
};
259267

260-
if !params.tools.is_empty() {
261-
session.tools = Some(params.tools);
262-
send_update = true;
263-
}
264-
265268
if let Some(voice) = params.voice {
266269
session.voice = Some(voice);
267270
send_update = true;
@@ -272,6 +275,16 @@ impl Client {
272275
send_update = true;
273276
}
274277

278+
if !params.tools.is_empty() {
279+
session.tools = Some(params.tools);
280+
send_update = true;
281+
}
282+
283+
if let Some(tool_choice) = params.tool_choice {
284+
session.tool_choice = Some(tool_choice);
285+
send_update = true;
286+
}
287+
275288
if send_update {
276289
self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate {
277290
event_id: None,
@@ -426,13 +439,15 @@ impl Client {
426439
voice,
427440
temperature,
428441
tools,
442+
tool_choice,
429443
} => {
430444
let event = ClientEvent::SessionUpdate(client_event::SessionUpdate {
431445
session: types::Session {
432-
tools,
433446
instructions,
434447
voice,
435448
temperature,
449+
tools,
450+
tool_choice,
436451
..Default::default()
437452
},
438453
..Default::default()

0 commit comments

Comments
 (0)