Skip to content

Commit be254fe

Browse files
authored
Merge pull request #45 from pragmatrix/openai-update-session-ext
OpenAI Dialog: Extend Session Update Service Event to include all fields that make
2 parents c350095 + 7d9f5c3 commit be254fe

1 file changed

Lines changed: 36 additions & 7 deletions

File tree

  • services/openai-dialog/src

services/openai-dialog/src/lib.rs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ use openai_api_rs::realtime::{
1515
api::RealtimeClient,
1616
client_event::{self, ClientEvent},
1717
server_event::{self, ServerEvent},
18-
types::{self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus},
18+
types::{
19+
self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus,
20+
ToolChoice,
21+
},
1922
};
2023
use serde::{Deserialize, Serialize};
2124
use tokio::{net::TcpStream, select};
@@ -42,6 +45,7 @@ pub struct Params {
4245
pub temperature: Option<f32>,
4346
#[serde(default)]
4447
pub tools: Vec<types::ToolDefinition>,
48+
tool_choice: Option<ToolChoice>,
4549
}
4650

4751
impl Params {
@@ -54,6 +58,7 @@ impl Params {
5458
voice: None,
5559
temperature: None,
5660
tools: vec![],
61+
tool_choice: None,
5762
}
5863
}
5964
}
@@ -113,9 +118,18 @@ pub enum ServiceInputEvent {
113118
Prompt {
114119
text: String,
115120
},
121+
#[serde(rename_all = "camelCase")]
116122
SessionUpdate {
123+
#[serde(skip_serializing_if = "Option::is_none")]
124+
instructions: Option<String>,
125+
#[serde(skip_serializing_if = "Option::is_none")]
126+
voice: Option<RealtimeVoice>,
127+
#[serde(skip_serializing_if = "Option::is_none")]
128+
temperature: Option<f32>,
117129
#[serde(skip_serializing_if = "Option::is_none")]
118130
tools: Option<Vec<types::ToolDefinition>>,
131+
#[serde(skip_serializing_if = "Option::is_none")]
132+
tool_choice: Option<ToolChoice>,
119133
},
120134
}
121135

@@ -251,11 +265,6 @@ impl Client {
251265
send_update = true;
252266
};
253267

254-
if !params.tools.is_empty() {
255-
session.tools = Some(params.tools);
256-
send_update = true;
257-
}
258-
259268
if let Some(voice) = params.voice {
260269
session.voice = Some(voice);
261270
send_update = true;
@@ -266,6 +275,16 @@ impl Client {
266275
send_update = true;
267276
}
268277

278+
if !params.tools.is_empty() {
279+
session.tools = Some(params.tools);
280+
send_update = true;
281+
}
282+
283+
if let Some(tool_choice) = params.tool_choice {
284+
session.tool_choice = Some(tool_choice);
285+
send_update = true;
286+
}
287+
269288
if send_update {
270289
self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate {
271290
event_id: None,
@@ -415,10 +434,20 @@ impl Client {
415434
info!("Received prompt");
416435
self.push_prompt(PromptRequest(text)).await?;
417436
}
418-
ServiceInputEvent::SessionUpdate { tools } => {
437+
ServiceInputEvent::SessionUpdate {
438+
instructions,
439+
voice,
440+
temperature,
441+
tools,
442+
tool_choice,
443+
} => {
419444
let event = ClientEvent::SessionUpdate(client_event::SessionUpdate {
420445
session: types::Session {
446+
instructions,
447+
voice,
448+
temperature,
421449
tools,
450+
tool_choice,
422451
..Default::default()
423452
},
424453
..Default::default()

0 commit comments

Comments
 (0)