Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod duration;
pub mod language;
mod protocol;
mod registry;
mod segment;
pub mod service;
pub mod speech_gate;

Expand All @@ -19,6 +20,7 @@ pub use conversation::*;
pub use duration::Duration;
pub use protocol::*;
pub use registry::*;
pub use segment::*;
pub use service::Service;

/// A unidirectional audio message. Useful for implementing an audio transfer channel.
Expand Down
102 changes: 102 additions & 0 deletions core/src/segment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use std::collections::BTreeSet;

use anyhow::Result;
use serde::{Deserialize, Serialize};

use crate::{ConversationOutput, OutputPath};

// Output path on which `SegmentController` publishes its segment-started service events.
const SEGMENT_OUTPUT_PATH: OutputPath = OutputPath::Media;

/// The phase a conversation is currently in, as tracked by [`SegmentController`].
///
/// Serialized as an internally tagged object: the variant name is written in camelCase
/// to the `phase` field, and struct-variant fields are written in camelCase next to it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "phase", rename_all = "camelCase")]
pub enum Segment {
    /// Entered via [`SegmentController::begin_user_speech`].
    UserSpeaking,
    /// Entered via [`SegmentController::begin_assistant_speech`].
    AssistantSpeaking,
    /// Entered via [`SegmentController::begin_processing`].
    Processing,
    /// Entered via [`SegmentController::begin_function_wait`] while function calls are
    /// outstanding.
    #[serde(rename_all = "camelCase")]
    WaitingForFunctionResult {
        /// Ids of all function calls that were pending when the segment started.
        call_ids: Vec<String>,
    },
    /// The state a freshly created controller starts in; also entered via
    /// [`SegmentController::become_idle`] or when the last pending function call ends.
    Idle,
}

/// Tracks the current [`Segment`] of a conversation and emits a service event whenever
/// the segment changes.
#[derive(Debug)]
pub struct SegmentController {
    /// The segment most recently entered. Initialized to `Some(Segment::Idle)`.
    current: Option<Segment>,
    /// Ids of function calls that have been announced but not yet ended.
    pending_function_call_ids: BTreeSet<String>,
    /// Output the segment-started events are sent to.
    event_output: ConversationOutput,
    /// Converts a segment into the JSON payload of the emitted service event.
    event_mapper: fn(Segment) -> serde_json::Value,
}

impl SegmentController {
    /// Creates a controller that starts in [`Segment::Idle`].
    ///
    /// The initial idle segment is not emitted; events are only sent on later transitions.
    /// `event_mapper` converts a segment into the JSON payload that is passed to
    /// `output.service_event`.
    pub fn new(output: ConversationOutput, event_mapper: fn(Segment) -> serde_json::Value) -> Self {
        Self {
            current: Some(Segment::Idle),
            pending_function_call_ids: BTreeSet::new(),
            event_output: output,
            event_mapper,
        }
    }

    /// Forgets all pending function calls and switches to [`Segment::UserSpeaking`].
    ///
    /// Returns the new segment, or `None` if the controller was already in it.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn begin_user_speech(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::UserSpeaking)
    }

    /// Forgets all pending function calls and switches to [`Segment::Processing`].
    ///
    /// Returns the new segment, or `None` if the controller was already in it.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn begin_processing(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::Processing)
    }

    /// Forgets all pending function calls and switches to [`Segment::AssistantSpeaking`].
    ///
    /// Returns the new segment, or `None` if the controller was already in it.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn begin_assistant_speech(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::AssistantSpeaking)
    }

    /// Adds `call_ids` to the pending function calls and switches to
    /// [`Segment::WaitingForFunctionResult`] listing every pending id.
    ///
    /// Returns `None` without changing the segment when no call is pending afterwards
    /// (i.e. `call_ids` was empty and nothing was pending before), or when the resulting
    /// segment equals the current one.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn begin_function_wait(&mut self, call_ids: Vec<String>) -> Result<Option<Segment>> {
        self.pending_function_call_ids.extend(call_ids);

        if self.pending_function_call_ids.is_empty() {
            return Ok(None);
        }

        let pending: Vec<String> = self.pending_function_call_ids.iter().cloned().collect();
        self.transition(Segment::WaitingForFunctionResult { call_ids: pending })
    }

    /// Marks the function call `call_id` as finished (answered or cancelled).
    ///
    /// Switches to [`Segment::Idle`] once the last pending call has ended. Returns `None`
    /// while other calls are still pending, and also when `call_id` was not pending at all.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn end_function_wait(&mut self, call_id: &str) -> Result<Option<Segment>> {
        // Every other `begin_*` / `become_idle` call clears the pending set, so a late
        // result or cancellation (for example after the user interrupted) refers to an id
        // that is no longer tracked. Such a stale id must not force the controller out of
        // whatever segment it is in now.
        if !self.pending_function_call_ids.remove(call_id) {
            return Ok(None);
        }

        if self.pending_function_call_ids.is_empty() {
            self.transition(Segment::Idle)
        } else {
            Ok(None)
        }
    }

    /// Forgets all pending function calls and switches to [`Segment::Idle`].
    ///
    /// Returns the new segment, or `None` if the controller was already idle.
    ///
    /// # Errors
    ///
    /// Fails if the segment event cannot be sent to the output.
    pub fn become_idle(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::Idle)
    }

    /// Returns `true` while the current segment is [`Segment::Idle`].
    pub fn is_idle(&self) -> bool {
        self.current == Some(Segment::Idle)
    }

    /// Enters `segment` and emits a segment-started event, unless `segment` equals the
    /// current one, in which case nothing happens and `None` is returned.
    fn transition(&mut self, segment: Segment) -> Result<Option<Segment>> {
        if self.current.as_ref() == Some(&segment) {
            return Ok(None);
        }

        self.current = Some(segment.clone());
        self.emit_segment_started(segment.clone())?;
        Ok(Some(segment))
    }

    /// Sends the mapped JSON representation of `segment` as a service event on
    /// `SEGMENT_OUTPUT_PATH`.
    fn emit_segment_started(&self, segment: Segment) -> Result<()> {
        self.event_output
            .service_event(SEGMENT_OUTPUT_PATH, (self.event_mapper)(segment))?;
        Ok(())
    }
}
4 changes: 4 additions & 0 deletions examples/dialog_providers/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ impl ProviderApi for GoogleProvider {

fn parse_service_event(&self, value: serde_json::Value) -> Result<Option<FunctionCall>> {
match serde_json::from_value(value)? {
ServiceOutputEvent::SegmentStarted { segment } => {
tracing::info!(?segment, "Gemini segment started");
Ok(None)
}
ServiceOutputEvent::FunctionCall {
name,
call_id,
Expand Down
4 changes: 4 additions & 0 deletions examples/dialog_providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl ProviderApi for OpenAIProvider {

fn parse_service_event(&self, value: serde_json::Value) -> Result<Option<FunctionCall>> {
match serde_json::from_value(value)? {
OpenAIServiceOutputEvent::SegmentStarted { segment } => {
tracing::info!(?segment, "OpenAI segment started");
Ok(None)
}
OpenAIServiceOutputEvent::FunctionCall {
name,
call_id,
Expand Down
52 changes: 48 additions & 4 deletions services/google-dialog/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ use context_switch_core::{
ConversationInput, ConversationOutput, Input, OutputPath,
};

const INPUT_VAD_PEAK_THRESHOLD: u16 = 900;

#[derive(Debug)]
pub struct Client {
params: Params,
Expand All @@ -37,7 +39,7 @@ impl Client {
output: ConversationOutput,
) -> Result<()> {
let billing_scope = self.params.model.clone();
let mut state = ConversationState::new();
let mut state = ConversationState::new(output.clone());
let mut session = match Session::connect(session_config(&self.params, text_outputs)?).await
{
Ok(session) => session,
Expand Down Expand Up @@ -90,6 +92,11 @@ impl Client {
match input {
Input::Audio { frame } => {
let mono = frame.into_mono();
if state.segment_controller.is_idle()
&& has_voice_activity(&mono.samples)
{
state.segment_controller.begin_user_speech()?;
}
let sample_rate = mono.format.sample_rate;
let audio = mono.to_le_bytes();
session
Expand All @@ -104,12 +111,17 @@ impl Client {
.context("Sending text to Gemini Live")?;
}
Input::ServiceEvent { value } => match serde_json::from_value(value)? {
ServiceInputEvent::FunctionCallResult { call_id, output } => {
ServiceInputEvent::FunctionCallResult {
call_id,
output: function_output,
} => {
let Some(name) = state.tool_calls.resolve(&call_id)? else {
return Ok(());
};

let response = normalize_function_response(output);
state.segment_controller.end_function_wait(&call_id)?;

let response = normalize_function_response(function_output);

let response = FunctionResponse {
id: call_id,
Expand Down Expand Up @@ -154,20 +166,32 @@ impl Client {
debug!(%text, "Gemini model text");
}
ServerEvent::ModelAudio(audio) => {
if !state.suppress_assistant_until_turn_complete {
state.segment_controller.begin_assistant_speech()?;
}

let frame = AudioFrame::from_le_bytes(output_format, &audio);
output.audio_frame(frame)?;
}
ServerEvent::GenerationComplete => {}
ServerEvent::TurnComplete => {
self.finalize_output_transcription(text_outputs, output, state)?;
state.segment_controller.become_idle()?;
state.suppress_assistant_until_turn_complete = false;
output.request_completed(None)?;
}
ServerEvent::Interrupted => {
// We expect a TurnComplete afterwards, so don't finalize the output transcription
// when interrupted.
state.segment_controller.begin_user_speech()?;
state.suppress_assistant_until_turn_complete = true;
output.clear_audio()?;
}
ServerEvent::InputTranscription(text) => {
if state.segment_controller.is_idle() {
state.segment_controller.begin_user_speech()?;
}

if self.params.input_audio_transcription {
if text_outputs.text {
output.text(true, text, None, None)?;
Expand All @@ -183,6 +207,9 @@ impl Client {
}
ServerEvent::OutputTranscription(text) => {
if self.params.output_audio_transcription {
if !state.suppress_assistant_until_turn_complete {
state.segment_controller.begin_assistant_speech()?;
}
state.output_transcription_buffer.push_str(&text);
if text_outputs.interim {
output.text(
Expand All @@ -202,6 +229,9 @@ impl Client {
}
}
ServerEvent::ToolCall(calls) => {
self.finalize_output_transcription(text_outputs, output, state)?;

let mut call_ids = Vec::new();
for call in calls {
// Send the function call via the media path.
//
Expand All @@ -220,8 +250,11 @@ impl Client {
},
)?;

call_ids.push(call.id.clone());
state.tool_calls.register(call.id, call.name)?;
}

state.segment_controller.begin_function_wait(call_ids)?;
}
ServerEvent::ToolCallCancellation(ids) => {
// Since we are sending function calls through the media path, we need to send
Expand All @@ -234,7 +267,9 @@ impl Client {
},
)?;

state.tool_calls.cancel(id)?;
state.tool_calls.cancel(id.clone())?;

state.segment_controller.end_function_wait(&id)?;
}
}
ServerEvent::SessionResumption { .. } => {}
Expand Down Expand Up @@ -268,6 +303,9 @@ impl Client {
let buffer = mem::take(&mut state.output_transcription_buffer);

if self.params.output_audio_transcription && text_outputs.text && !buffer.is_empty() {
if !state.suppress_assistant_until_turn_complete {
state.segment_controller.begin_assistant_speech()?;
}
output.text(true, buffer, None, Some(AI_ASSISTANT_SPEAKER.into()))?;
}
Ok(())
Expand All @@ -282,6 +320,12 @@ fn normalize_function_response(output: serde_json::Value) -> serde_json::Value {
}
}

/// Reports whether any sample's absolute amplitude reaches `INPUT_VAD_PEAK_THRESHOLD`.
///
/// Returns `false` for an empty slice.
fn has_voice_activity(samples: &[i16]) -> bool {
    for &sample in samples {
        // `unsigned_abs` avoids the overflow that `abs` would hit on `i16::MIN`.
        let peak = sample.unsigned_abs();
        if peak >= INPUT_VAD_PEAK_THRESHOLD {
            return true;
        }
    }
    false
}

fn session_config(params: &Params, text_outputs: TextOutputs) -> Result<SessionConfig> {
let transport = TransportConfig {
endpoint: params
Expand Down
12 changes: 11 additions & 1 deletion services/google-dialog/src/conversation_state.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use std::collections::hash_map;

use anyhow::{Result, bail};
use context_switch_core::{ConversationOutput, SegmentController};
use tracing::warn;

use crate::ServiceOutputEvent;

/// Mutable per-conversation state shared by the client's input and server-event handlers.
#[derive(Debug)]
pub struct ConversationState {
    /// Accumulates output transcription text until it is finalized.
    pub output_transcription_buffer: String,
    /// Tracks the current conversation segment and emits segment-started events.
    pub segment_controller: SegmentController,
    /// Set when the server reports an interruption and cleared on turn completion; while
    /// set, model output does not switch the segment to assistant speech.
    pub suppress_assistant_until_turn_complete: bool,
    /// Registry of tool calls, used to register, resolve and cancel calls by id.
    pub tool_calls: ToolCallTracker,
}

Expand All @@ -21,9 +26,14 @@ enum ToolCallEntry {
}

impl ConversationState {
pub fn new() -> Self {
pub fn new(output: ConversationOutput) -> Self {
Self {
output_transcription_buffer: String::new(),
segment_controller: SegmentController::new(
output,
ServiceOutputEvent::segment_started_json,
),
suppress_assistant_until_turn_complete: false,
tool_calls: ToolCallTracker::default(),
}
}
Expand Down
10 changes: 10 additions & 0 deletions services/google-dialog/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, ThinkingLevel, Tool};
use context_switch_core::Segment;
use serde::{Deserialize, Deserializer, Serialize};

use anyhow::{Result, bail};
Expand Down Expand Up @@ -176,6 +177,8 @@ pub enum ServiceInputEvent {
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ServiceOutputEvent {
#[serde(rename_all = "camelCase")]
SegmentStarted { segment: Segment },
#[serde(rename_all = "camelCase")]
FunctionCall {
call_id: String,
Expand All @@ -186,6 +189,13 @@ pub enum ServiceOutputEvent {
ToolCallCancellation { call_id: String },
}

impl ServiceOutputEvent {
    /// Builds the JSON payload of a `SegmentStarted` event for `segment`.
    ///
    /// # Panics
    ///
    /// Panics if the event cannot be serialized to a JSON value.
    pub fn segment_started_json(segment: Segment) -> serde_json::Value {
        let event = Self::SegmentStarted { segment };
        let serialized = serde_json::to_value(event);
        serialized.expect("SegmentStarted serialization must succeed")
    }
}

#[cfg(test)]
mod tests {
use super::Params;
Expand Down
Loading
Loading