diff --git a/core/src/lib.rs b/core/src/lib.rs
index 6e24baa..723e7e1 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -6,6 +6,7 @@ mod duration;
 pub mod language;
 mod protocol;
 mod registry;
+mod segment;
 pub mod service;
 pub mod speech_gate;
 
@@ -19,6 +20,7 @@ pub use conversation::*;
 pub use duration::Duration;
 pub use protocol::*;
 pub use registry::*;
+pub use segment::*;
 pub use service::Service;
 
 /// A unidirectional audio message. Useful for implementing an audio transfer channel.
diff --git a/core/src/segment.rs b/core/src/segment.rs
new file mode 100644
index 0000000..4a80496
--- /dev/null
+++ b/core/src/segment.rs
@@ -0,0 +1,133 @@
+use std::collections::BTreeSet;
+
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+
+use crate::{ConversationOutput, OutputPath};
+
+/// Output path on which segment events are emitted.
+const SEGMENT_OUTPUT_PATH: OutputPath = OutputPath::Media;
+
+/// The phase a conversation is currently in.
+///
+/// Serialized as an internally tagged object, e.g. `{"phase": "userSpeaking"}`.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(tag = "phase", rename_all = "camelCase")]
+pub enum Segment {
+    UserSpeaking,
+    AssistantSpeaking,
+    Processing,
+    #[serde(rename_all = "camelCase")]
+    WaitingForFunctionResult {
+        /// Ids of the function calls whose results are still outstanding.
+        call_ids: Vec<String>,
+    },
+    Idle,
+}
+
+/// Tracks the current [`Segment`] and emits a service event whenever it changes.
+#[derive(Debug)]
+pub struct SegmentController {
+    current: Segment,
+    pending_function_call_ids: BTreeSet<String>,
+    event_output: ConversationOutput,
+    event_mapper: fn(Segment) -> serde_json::Value,
+}
+
+impl SegmentController {
+    /// Creates a controller that starts out in [`Segment::Idle`].
+    ///
+    /// `event_mapper` converts a segment into the event payload sent through `output`.
+    /// No event is emitted for the initial idle segment.
+    pub fn new(output: ConversationOutput, event_mapper: fn(Segment) -> serde_json::Value) -> Self {
+        Self {
+            current: Segment::Idle,
+            pending_function_call_ids: BTreeSet::new(),
+            event_output: output,
+            event_mapper,
+        }
+    }
+
+    /// Enters [`Segment::UserSpeaking`] and drops all pending function calls.
+    ///
+    /// Returns the new segment, or `None` if it was already the current one.
+    pub fn begin_user_speech(&mut self) -> Result<Option<Segment>> {
+        self.pending_function_call_ids.clear();
+        self.transition(Segment::UserSpeaking)
+    }
+
+    /// Enters [`Segment::Processing`] and drops all pending function calls.
+    pub fn begin_processing(&mut self) -> Result<Option<Segment>> {
+        self.pending_function_call_ids.clear();
+        self.transition(Segment::Processing)
+    }
+
+    /// Enters [`Segment::AssistantSpeaking`] and drops all pending function calls.
+    pub fn begin_assistant_speech(&mut self) -> Result<Option<Segment>> {
+        self.pending_function_call_ids.clear();
+        self.transition(Segment::AssistantSpeaking)
+    }
+
+    /// Adds `call_ids` to the pending function calls and enters
+    /// [`Segment::WaitingForFunctionResult`] listing every pending call.
+    ///
+    /// Does nothing if no function call is pending afterwards.
+    pub fn begin_function_wait(&mut self, call_ids: Vec<String>) -> Result<Option<Segment>> {
+        self.pending_function_call_ids.extend(call_ids);
+
+        if self.pending_function_call_ids.is_empty() {
+            return Ok(None);
+        }
+
+        self.transition(Segment::WaitingForFunctionResult {
+            call_ids: self.pending_function_call_ids.iter().cloned().collect(),
+        })
+    }
+
+    /// Marks the function call `call_id` as finished.
+    ///
+    /// Returns to [`Segment::Idle`] once the last pending call has finished. A result for a
+    /// call that is no longer pending (every other transition clears the pending set) is
+    /// ignored, so a late result cannot replace the segment that superseded the wait.
+    pub fn end_function_wait(&mut self, call_id: &str) -> Result<Option<Segment>> {
+        let was_pending = self.pending_function_call_ids.remove(call_id);
+        let is_waiting = matches!(self.current, Segment::WaitingForFunctionResult { .. });
+
+        if was_pending && is_waiting && self.pending_function_call_ids.is_empty() {
+            self.transition(Segment::Idle)
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Enters [`Segment::Idle`] and drops all pending function calls.
+    pub fn become_idle(&mut self) -> Result<Option<Segment>> {
+        self.pending_function_call_ids.clear();
+        self.transition(Segment::Idle)
+    }
+
+    /// Returns `true` while the current segment is [`Segment::Idle`].
+    pub fn is_idle(&self) -> bool {
+        self.current == Segment::Idle
+    }
+
+    /// Switches to `segment` and emits it, unless it equals the current segment.
+    fn transition(&mut self, segment: Segment) -> Result<Option<Segment>> {
+        if self.current == segment {
+            return Ok(None);
+        }
+
+        // Emit before committing: if delivery fails, the state stays unchanged and a later
+        // attempt is not swallowed by the duplicate check above.
+        self.emit_segment_started(segment.clone())?;
+        self.current = segment.clone();
+        Ok(Some(segment))
+    }
+
+    /// Sends the mapped segment event on the segment output path.
+    fn emit_segment_started(&self, segment: Segment) -> Result<()> {
+        self.event_output
+            .service_event(SEGMENT_OUTPUT_PATH, (self.event_mapper)(segment))?;
+        Ok(())
+    }
+}
diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs
index 58fb4ce..d882c0f 100644
--- a/examples/dialog_providers/google.rs
+++ b/examples/dialog_providers/google.rs
@@ -50,6 +50,10 @@ impl ProviderApi for GoogleProvider {
     fn parse_service_event(&self, value: serde_json::Value) -> Result> {
         match serde_json::from_value(value)?
{ + ServiceOutputEvent::SegmentStarted { segment } => { + tracing::info!(?segment, "Gemini segment started"); + Ok(None) + } ServiceOutputEvent::FunctionCall { name, call_id, diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs index 3ed024b..27df37f 100644 --- a/examples/dialog_providers/openai.rs +++ b/examples/dialog_providers/openai.rs @@ -52,6 +52,10 @@ impl ProviderApi for OpenAIProvider { fn parse_service_event(&self, value: serde_json::Value) -> Result> { match serde_json::from_value(value)? { + OpenAIServiceOutputEvent::SegmentStarted { segment } => { + tracing::info!(?segment, "OpenAI segment started"); + Ok(None) + } OpenAIServiceOutputEvent::FunctionCall { name, call_id, diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 1c12fcb..6bff919 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -19,6 +19,8 @@ use context_switch_core::{ ConversationInput, ConversationOutput, Input, OutputPath, }; +const INPUT_VAD_PEAK_THRESHOLD: u16 = 900; + #[derive(Debug)] pub struct Client { params: Params, @@ -37,7 +39,7 @@ impl Client { output: ConversationOutput, ) -> Result<()> { let billing_scope = self.params.model.clone(); - let mut state = ConversationState::new(); + let mut state = ConversationState::new(output.clone()); let mut session = match Session::connect(session_config(&self.params, text_outputs)?).await { Ok(session) => session, @@ -90,6 +92,11 @@ impl Client { match input { Input::Audio { frame } => { let mono = frame.into_mono(); + if state.segment_controller.is_idle() + && has_voice_activity(&mono.samples) + { + state.segment_controller.begin_user_speech()?; + } let sample_rate = mono.format.sample_rate; let audio = mono.to_le_bytes(); session @@ -104,12 +111,17 @@ impl Client { .context("Sending text to Gemini Live")?; } Input::ServiceEvent { value } => match serde_json::from_value(value)? 
{ - ServiceInputEvent::FunctionCallResult { call_id, output } => { + ServiceInputEvent::FunctionCallResult { + call_id, + output: function_output, + } => { let Some(name) = state.tool_calls.resolve(&call_id)? else { return Ok(()); }; - let response = normalize_function_response(output); + state.segment_controller.end_function_wait(&call_id)?; + + let response = normalize_function_response(function_output); let response = FunctionResponse { id: call_id, @@ -154,20 +166,32 @@ impl Client { debug!(%text, "Gemini model text"); } ServerEvent::ModelAudio(audio) => { + if !state.suppress_assistant_until_turn_complete { + state.segment_controller.begin_assistant_speech()?; + } + let frame = AudioFrame::from_le_bytes(output_format, &audio); output.audio_frame(frame)?; } ServerEvent::GenerationComplete => {} ServerEvent::TurnComplete => { self.finalize_output_transcription(text_outputs, output, state)?; + state.segment_controller.become_idle()?; + state.suppress_assistant_until_turn_complete = false; output.request_completed(None)?; } ServerEvent::Interrupted => { // We expect a TurnComplete afterwards, so don't finalize the output transcription // when interrupted. 
+ state.segment_controller.begin_user_speech()?; + state.suppress_assistant_until_turn_complete = true; output.clear_audio()?; } ServerEvent::InputTranscription(text) => { + if state.segment_controller.is_idle() { + state.segment_controller.begin_user_speech()?; + } + if self.params.input_audio_transcription { if text_outputs.text { output.text(true, text, None, None)?; @@ -183,6 +207,9 @@ impl Client { } ServerEvent::OutputTranscription(text) => { if self.params.output_audio_transcription { + if !state.suppress_assistant_until_turn_complete { + state.segment_controller.begin_assistant_speech()?; + } state.output_transcription_buffer.push_str(&text); if text_outputs.interim { output.text( @@ -202,6 +229,9 @@ impl Client { } } ServerEvent::ToolCall(calls) => { + self.finalize_output_transcription(text_outputs, output, state)?; + + let mut call_ids = Vec::new(); for call in calls { // Send the function call via the media path. // @@ -220,8 +250,11 @@ impl Client { }, )?; + call_ids.push(call.id.clone()); state.tool_calls.register(call.id, call.name)?; } + + state.segment_controller.begin_function_wait(call_ids)?; } ServerEvent::ToolCallCancellation(ids) => { // Since we are sending function calls through the media path, we need to send @@ -234,7 +267,9 @@ impl Client { }, )?; - state.tool_calls.cancel(id)?; + state.tool_calls.cancel(id.clone())?; + + state.segment_controller.end_function_wait(&id)?; } } ServerEvent::SessionResumption { .. 
} => {} @@ -268,6 +303,9 @@ impl Client { let buffer = mem::take(&mut state.output_transcription_buffer); if self.params.output_audio_transcription && text_outputs.text && !buffer.is_empty() { + if !state.suppress_assistant_until_turn_complete { + state.segment_controller.begin_assistant_speech()?; + } output.text(true, buffer, None, Some(AI_ASSISTANT_SPEAKER.into()))?; } Ok(()) @@ -282,6 +320,12 @@ fn normalize_function_response(output: serde_json::Value) -> serde_json::Value { } } +fn has_voice_activity(samples: &[i16]) -> bool { + samples + .iter() + .any(|sample| sample.unsigned_abs() >= INPUT_VAD_PEAK_THRESHOLD) +} + fn session_config(params: &Params, text_outputs: TextOutputs) -> Result { let transport = TransportConfig { endpoint: params diff --git a/services/google-dialog/src/conversation_state.rs b/services/google-dialog/src/conversation_state.rs index d873fae..e506580 100644 --- a/services/google-dialog/src/conversation_state.rs +++ b/services/google-dialog/src/conversation_state.rs @@ -1,11 +1,16 @@ use std::collections::hash_map; use anyhow::{Result, bail}; +use context_switch_core::{ConversationOutput, SegmentController}; use tracing::warn; +use crate::ServiceOutputEvent; + #[derive(Debug)] pub struct ConversationState { pub output_transcription_buffer: String, + pub segment_controller: SegmentController, + pub suppress_assistant_until_turn_complete: bool, pub tool_calls: ToolCallTracker, } @@ -21,9 +26,14 @@ enum ToolCallEntry { } impl ConversationState { - pub fn new() -> Self { + pub fn new(output: ConversationOutput) -> Self { Self { output_transcription_buffer: String::new(), + segment_controller: SegmentController::new( + output, + ServiceOutputEvent::segment_started_json, + ), + suppress_assistant_until_turn_complete: false, tool_calls: ToolCallTracker::default(), } } diff --git a/services/google-dialog/src/types.rs b/services/google-dialog/src/types.rs index 935fee8..c7f725f 100644 --- a/services/google-dialog/src/types.rs +++ 
b/services/google-dialog/src/types.rs @@ -1,4 +1,5 @@ use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, ThinkingLevel, Tool}; +use context_switch_core::Segment; use serde::{Deserialize, Deserializer, Serialize}; use anyhow::{Result, bail}; @@ -176,6 +177,8 @@ pub enum ServiceInputEvent { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ServiceOutputEvent { + #[serde(rename_all = "camelCase")] + SegmentStarted { segment: Segment }, #[serde(rename_all = "camelCase")] FunctionCall { call_id: String, @@ -186,6 +189,13 @@ pub enum ServiceOutputEvent { ToolCallCancellation { call_id: String }, } +impl ServiceOutputEvent { + pub fn segment_started_json(segment: Segment) -> serde_json::Value { + serde_json::to_value(Self::SegmentStarted { segment }) + .expect("SegmentStarted serialization must succeed") + } +} + #[cfg(test)] mod tests { use super::Params; diff --git a/services/openai-dialog/src/client.rs b/services/openai-dialog/src/client.rs index 90756f1..f808fcc 100644 --- a/services/openai-dialog/src/client.rs +++ b/services/openai-dialog/src/client.rs @@ -1,5 +1,6 @@ #[cfg(feature = "prompt-delay")] use std::collections::VecDeque; +use std::collections::HashSet; use anyhow::{Context, Result, bail}; use base64::prelude::*; @@ -19,13 +20,15 @@ use crate::transcription::{TranscriptionSettings, TranscriptionState}; use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; use context_switch_core::{ AI_ASSISTANT_SPEAKER, AudioFormat, AudioFrame, BillingRecord, BillingSchedule, - ConversationInput, ConversationOutput, Input, OutputPath, audio, + ConversationInput, ConversationOutput, Input, OutputPath, SegmentController, audio, }; pub struct Client { read: SplitStream>>, write: SplitSink>, Message>, + segment_controller: Option, transcription_state: TranscriptionState, + pending_input_processing_item_ids: HashSet, #[cfg(feature = "prompt-delay")] prompt_coordinator: PromptCoordinator, @@ -54,7 +57,9 @@ 
impl Client { Self { read, write, + segment_controller: None, transcription_state: TranscriptionState::default(), + pending_input_processing_item_ids: HashSet::new(), #[cfg(feature = "prompt-delay")] prompt_coordinator: PromptCoordinator::new(), } @@ -94,6 +99,11 @@ impl Client { debug!("Session created"); + self.segment_controller = Some(SegmentController::new( + output.clone(), + ServiceOutputEvent::segment_started_json, + )); + { let mut send_update = false; let mut session = types::RealtimeSession::default(); @@ -283,6 +293,12 @@ impl Client { Ok(()) } + fn segment_controller(&mut self) -> &mut SegmentController { + self.segment_controller + .as_mut() + .expect("segment controller must be initialized in dialog()") + } + async fn process_input(&mut self, input: Input) -> Result<()> { match input { Input::Text { .. } => { @@ -427,6 +443,7 @@ impl Client { self.handle_server_error(raw, &e)?; } ServerEvent::ResponseOutputAudioDelta(audio_delta) => { + self.segment_controller().begin_assistant_speech()?; let decoded = BASE64_STANDARD.decode(audio_delta.delta)?; let samples = audio::from_le_bytes(&decoded); trace!("Sending {} samples", samples.len()); @@ -436,7 +453,19 @@ impl Client { }; output.audio_frame(frame)?; } - ServerEvent::InputAudioBufferSpeechStarted(_) => output.clear_audio()?, + ServerEvent::InputAudioBufferSpeechStarted(_) => { + self.segment_controller().begin_user_speech()?; + output.clear_audio()?; + } + ServerEvent::InputAudioBufferSpeechStopped( + server_event::InputAudioBufferSpeechStopped { item_id, .. }, + ) => { + if transcription.input { + self.pending_input_processing_item_ids.insert(item_id); + } else { + self.segment_controller().begin_processing()?; + } + } ServerEvent::ConversationItemInputAudioTranscriptionDelta( server_event::ConversationItemInputAudioTranscriptionDelta { item_id, @@ -460,6 +489,7 @@ impl Client { .. 
}, ) => { + let processing_item_id = item_id.clone(); if transcription.input && let Some(text) = self.transcription_state.complete_input_transcription( item_id, @@ -469,6 +499,12 @@ impl Client { { output.text(true, text, None, None)?; } + if self + .pending_input_processing_item_ids + .remove(&processing_item_id) + { + self.segment_controller().begin_processing()?; + } } ServerEvent::ResponseOutputAudioTranscriptDelta( server_event::ResponseOutputAudioTranscriptDelta { @@ -516,6 +552,7 @@ impl Client { .. }) => { self.transcription_state.clear_item(&item_id); + self.pending_input_processing_item_ids.remove(&item_id); } ServerEvent::ConversationItemTruncated(server_event::ConversationItemTruncated { item_id, @@ -529,6 +566,9 @@ impl Client { response: types::Response { object, .. }, .. }) if object == "realtime.response" => { + self.pending_input_processing_item_ids.clear(); + self.segment_controller().begin_processing()?; + #[cfg(feature = "prompt-delay")] self.prompt_coordinator .update_response_state(&mut self.write, ResponseState::Responding) @@ -551,8 +591,7 @@ impl Client { // .transcription_state // .has_output_transcript_events_for_response(&response_id); - #[cfg(feature = "prompt-delay")] - let mut any_function_call_request = false; + let mut function_call_ids = Vec::new(); for item in items { trace!("Response Item: {item:?}"); match (&status, &item.r#type, &item.status) { @@ -592,10 +631,7 @@ impl Client { arguments, }, )?; - #[cfg(feature = "prompt-delay")] - { - any_function_call_request = true; - } + function_call_ids.push(call_id.clone()); } // Disabled fallback path: ignore transcript fields in `realtime.response` // and only finalize via `response.output_audio_transcript.done`. 
@@ -627,6 +663,11 @@ impl Client { } } + if !function_call_ids.is_empty() { + self.segment_controller() + .begin_function_wait(function_call_ids.clone())?; + } + if let Some(usage) = usage { let input_details = &usage.input_token_details; let output_details = &usage.output_token_details; @@ -667,12 +708,16 @@ impl Client { )?; } + if function_call_ids.is_empty() { + self.segment_controller().become_idle()?; + } + #[cfg(feature = "prompt-delay")] { self.prompt_coordinator .update_response_state( &mut self.write, - if any_function_call_request { + if !function_call_ids.is_empty() { ResponseState::ExpectingFunctionResult } else { ResponseState::Idle diff --git a/services/openai-dialog/src/types.rs b/services/openai-dialog/src/types.rs index 73b737d..43c475c 100644 --- a/services/openai-dialog/src/types.rs +++ b/services/openai-dialog/src/types.rs @@ -1,4 +1,5 @@ use openai_api_rs::realtime::types::{self, RealtimeVoice, ToolChoice}; +use context_switch_core::Segment; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] @@ -63,6 +64,8 @@ pub enum ServiceInputEvent { #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ServiceOutputEvent { + #[serde(rename_all = "camelCase")] + SegmentStarted { segment: Segment }, #[serde(rename_all = "camelCase")] FunctionCall { call_id: String, @@ -77,3 +80,10 @@ pub enum ServiceOutputEvent { tools: Option>, }, } + +impl ServiceOutputEvent { + pub fn segment_started_json(segment: Segment) -> serde_json::Value { + serde_json::to_value(Self::SegmentStarted { segment }) + .expect("SegmentStarted serialization must succeed") + } +}