-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsegment.rs
More file actions
102 lines (83 loc) · 2.93 KB
/
Copy pathsegment.rs
File metadata and controls
102 lines (83 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use std::collections::BTreeSet;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use crate::{ConversationOutput, OutputPath};
/// Output path on which every segment-started service event is published.
const SEGMENT_OUTPUT_PATH: OutputPath = OutputPath::Media;
/// Conversation phase tracked by `SegmentController`.
///
/// Serialized as an internally tagged object: the `phase` field carries the camelCase variant
/// name (e.g. `"userSpeaking"`). The fields of `WaitingForFunctionResult` are camelCased too,
/// so `call_ids` is emitted as `callIds`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "phase", rename_all = "camelCase")]
pub enum Segment {
    /// Entered through `SegmentController::begin_user_speech`.
    UserSpeaking,
    /// Entered through `SegmentController::begin_assistant_speech`.
    AssistantSpeaking,
    /// Entered through `SegmentController::begin_processing`.
    Processing,
    /// Entered through `SegmentController::begin_function_wait`; lists the pending call ids.
    #[serde(rename_all = "camelCase")]
    WaitingForFunctionResult { call_ids: Vec<String> },
    /// Initial phase; also entered through `SegmentController::become_idle` or once the last
    /// pending function call has finished.
    Idle,
}
/// Tracks the current conversation [`Segment`] and publishes a service event whenever it changes.
#[derive(Debug)]
pub struct SegmentController {
    // Segment most recently transitioned to. Initialized to `Some(Segment::Idle)` in `new` and
    // only ever reassigned to `Some(..)` by `transition`.
    // NOTE(review): never `None` in the visible code; the `Option` could likely be dropped.
    current: Option<Segment>,
    // Function-call ids still awaiting a result. Cleared by `begin_user_speech`,
    // `begin_processing`, `begin_assistant_speech` and `become_idle`.
    pending_function_call_ids: BTreeSet<String>,
    // Destination of the segment-started service events.
    event_output: ConversationOutput,
    // Converts a newly started segment into the JSON payload of its event.
    event_mapper: fn(Segment) -> serde_json::Value,
}
impl SegmentController {
    /// Creates a controller that starts in [`Segment::Idle`] with no pending function calls.
    ///
    /// No event is emitted for the initial segment. Every later segment change is published
    /// through `output`, and `event_mapper` turns the new [`Segment`] into the event payload.
    pub fn new(
        output: ConversationOutput,
        event_mapper: fn(Segment) -> serde_json::Value,
    ) -> Self {
        Self {
            current: Some(Segment::Idle),
            pending_function_call_ids: BTreeSet::new(),
            event_output: output,
            event_mapper,
        }
    }

    /// Drops any pending function calls and enters [`Segment::UserSpeaking`].
    ///
    /// Returns `Ok(Some(segment))` when the segment changed (an event was emitted) and
    /// `Ok(None)` when the controller was already in that segment.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn begin_user_speech(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::UserSpeaking)
    }

    /// Drops any pending function calls and enters [`Segment::Processing`].
    ///
    /// Returns `Ok(Some(segment))` when the segment changed and `Ok(None)` otherwise.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn begin_processing(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::Processing)
    }

    /// Drops any pending function calls and enters [`Segment::AssistantSpeaking`].
    ///
    /// Returns `Ok(Some(segment))` when the segment changed and `Ok(None)` otherwise.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn begin_assistant_speech(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::AssistantSpeaking)
    }

    /// Registers `call_ids` as awaiting results and enters
    /// [`Segment::WaitingForFunctionResult`].
    ///
    /// Ids accumulate across calls. The emitted segment lists every still-pending id in sorted
    /// order, not just the ones passed here. If nothing is pending afterwards (`call_ids` is
    /// empty and no earlier call is outstanding), the segment is left unchanged and `Ok(None)`
    /// is returned.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn begin_function_wait(&mut self, call_ids: Vec<String>) -> Result<Option<Segment>> {
        self.pending_function_call_ids.extend(call_ids);
        if self.pending_function_call_ids.is_empty() {
            return Ok(None);
        }
        self.transition(Segment::WaitingForFunctionResult {
            // BTreeSet iterates in sorted order, so the id list is deterministic.
            call_ids: self.pending_function_call_ids.iter().cloned().collect(),
        })
    }

    /// Marks the function call `call_id` as finished.
    ///
    /// When it was the last pending call, the controller enters [`Segment::Idle`]. Until then,
    /// the current segment keeps the `call_ids` list it was entered with.
    ///
    /// Ids that are not pending are ignored and leave the current segment untouched. This
    /// covers unknown ids, ids that already finished, and ids dropped because another segment
    /// began in the meantime.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn end_function_wait(&mut self, call_id: &str) -> Result<Option<Segment>> {
        // `remove` reports whether the id was actually pending. Without this guard, a stale
        // result arriving after e.g. `begin_user_speech` cleared the set would find it empty
        // and force a spurious transition to Idle.
        if !self.pending_function_call_ids.remove(call_id) {
            return Ok(None);
        }
        if self.pending_function_call_ids.is_empty() {
            self.transition(Segment::Idle)
        } else {
            Ok(None)
        }
    }

    /// Drops any pending function calls and enters [`Segment::Idle`].
    ///
    /// Returns `Ok(Some(segment))` when the segment changed and `Ok(None)` otherwise.
    ///
    /// # Errors
    /// Propagates any error from publishing the segment-started event.
    pub fn become_idle(&mut self) -> Result<Option<Segment>> {
        self.pending_function_call_ids.clear();
        self.transition(Segment::Idle)
    }

    /// Returns `true` while the controller is in [`Segment::Idle`].
    pub fn is_idle(&self) -> bool {
        matches!(self.current, Some(Segment::Idle))
    }

    /// Moves to `segment`, emitting a segment-started event if it differs from the current one.
    ///
    /// The new segment is recorded before the event is emitted. It therefore stays current even
    /// if emission fails and the error is returned.
    fn transition(&mut self, segment: Segment) -> Result<Option<Segment>> {
        if self.current.as_ref() == Some(&segment) {
            return Ok(None);
        }
        self.current = Some(segment.clone());
        self.emit_segment_started(segment.clone())?;
        Ok(Some(segment))
    }

    /// Publishes `segment`, mapped through `event_mapper`, as a service event on
    /// `SEGMENT_OUTPUT_PATH`.
    fn emit_segment_started(&self, segment: Segment) -> Result<()> {
        self.event_output
            .service_event(SEGMENT_OUTPUT_PATH, (self.event_mapper)(segment))?;
        Ok(())
    }
}