@@ -4,18 +4,19 @@ use std::time::Duration;
44
55use anyhow:: { Context , Result , bail} ;
66use clap:: { Parser , ValueEnum } ;
7+ use tokio:: select;
8+ use tokio:: sync:: mpsc:: { channel, unbounded_channel} ;
9+
710use cpal:: traits:: { DeviceTrait , HostTrait , StreamTrait } ;
811use openai_api_rs:: realtime:: types:: {
912 AzureSemanticVadConfig , EndOfUtteranceDetectionConfig , EndOfUtteranceDetectionModel ,
1013 EndOfUtteranceThresholdLevel , TurnDetection ,
1114} ;
1215use rodio:: DeviceSinkBuilder ;
13- use tokio:: select;
14- use tokio:: sync:: mpsc:: { channel, unbounded_channel} ;
1516
1617use context_switch:: services:: {
17- AristechTranscribe , AzureTranscribe , ElevenLabsTranscribe , GoogleTranscribe ,
18- MicrosoftVoiceLiveTranscribe ,
18+ AristechTranscribe , AzureTranscribe , DeepgramTranscribe , ElevenLabsTranscribe ,
19+ GoogleTranscribe , MicrosoftVoiceLiveTranscribe ,
1920} ;
2021use context_switch:: { AudioConsumer , InputModality , OutputModality } ;
2122use context_switch_core:: language:: Languages ;
@@ -51,6 +52,8 @@ enum Provider {
5152 Aristech ,
5253 #[ value( name = "voice-live" ) ]
5354 VoiceLive ,
55+ #[ value( name = "deepgram" ) ]
56+ Deepgram ,
5457}
5558
5659#[ tokio:: main]
@@ -240,11 +243,10 @@ async fn start_conversation(
240243 diarization : bool ,
241244 conversation : Conversation ,
242245) -> Result < ( ) > {
246+ validate_provider_args ( provider, model, region, diarization) ?;
247+
243248 match provider {
244249 Provider :: Azure => {
245- if region. is_some ( ) {
246- bail ! ( "--region is only supported for the google provider" ) ;
247- }
248250 let params = azure:: transcribe:: Params {
249251 endpoint : env:: var ( "AZURE_ENDPOINT" )
250252 . ok ( )
@@ -259,12 +261,6 @@ async fn start_conversation(
259261 AzureTranscribe . conversation ( params, conversation) . await
260262 }
261263 Provider :: Elevenlabs => {
262- if diarization {
263- bail ! ( "--diarization is only supported for the azure provider" ) ;
264- }
265- if region. is_some ( ) {
266- bail ! ( "--region is only supported for the google provider" ) ;
267- }
268264 let language = Some (
269265 languages
270266 . single ( )
@@ -317,12 +313,6 @@ async fn start_conversation(
317313 GoogleTranscribe . conversation ( params, conversation) . await
318314 }
319315 Provider :: Aristech => {
320- if diarization {
321- bail ! ( "--diarization is only supported for the azure provider" ) ;
322- }
323- if region. is_some ( ) {
324- bail ! ( "--region is only supported for the google provider" ) ;
325- }
326316 let language = languages
327317 . single ( )
328318 . context ( "Aristech provider supports exactly one --language value" ) ?
@@ -351,13 +341,6 @@ async fn start_conversation(
351341 AristechTranscribe . conversation ( params, conversation) . await
352342 }
353343 Provider :: VoiceLive => {
354- if diarization {
355- bail ! ( "--diarization is only supported for the azure provider" ) ;
356- }
357- if region. is_some ( ) {
358- bail ! ( "--region is only supported for the google provider" ) ;
359- }
360-
361344 let language = Some (
362345 languages
363346 . single ( )
@@ -383,7 +366,7 @@ async fn start_conversation(
383366 AzureSemanticVadConfig {
384367 end_of_utterance_detection : Some ( EndOfUtteranceDetectionConfig {
385368 model : EndOfUtteranceDetectionModel :: SmartEndOfTurnDetection ,
386- threshold_level : Some ( EndOfUtteranceThresholdLevel :: Low ) ,
369+ threshold_level : Some ( EndOfUtteranceThresholdLevel :: Default ) ,
387370 timeout_ms : Some ( 5000 ) ,
388371 } ) ,
389372 // remove_filler_words: Some(true),
@@ -396,5 +379,91 @@ async fn start_conversation(
396379 . conversation ( params, conversation)
397380 . await
398381 }
382+ Provider :: Deepgram => {
383+ let params = deepgram_service:: transcribe:: Params {
384+ api_key : env:: var ( "DEEPGRAM_API_KEY" ) . expect ( "DEEPGRAM_API_KEY undefined" ) ,
385+ endpoint : env:: var ( "DEEPGRAM_ENDPOINT" ) . expect ( "DEEPGRAM_ENDPOINT undefined" ) ,
386+ language : languages. join_csv ( ) ,
387+ profanity_filter : false ,
388+ keyterm : vec ! [ ] ,
389+ turn_detection : deepgram_service:: transcribe:: TurnDetection :: default ( ) ,
390+ } ;
391+
392+ DeepgramTranscribe . conversation ( params, conversation) . await
393+ }
394+ }
395+ }
396+
/// Per-provider switches recording which optional command-line arguments are accepted.
///
/// The derived `Default` leaves every switch off, so a provider only has to
/// enable what it actually supports.
#[derive(Debug, Clone, Copy, Default)]
struct ProviderCapabilities {
    /// `true` when the provider accepts `--region`.
    region: bool,
    /// `true` when the provider accepts `--diarization`.
    diarization: bool,
    /// `true` when the provider accepts `--model`.
    model: bool,
}
403+
404+ impl Provider {
405+ fn capabilities ( self ) -> ProviderCapabilities {
406+ let mut capabilities = ProviderCapabilities :: default ( ) ;
407+
408+ match self {
409+ Provider :: Azure => {
410+ capabilities. diarization = true ;
411+ capabilities. model = true ;
412+ }
413+ Provider :: Deepgram => { }
414+ Provider :: Elevenlabs => {
415+ capabilities. model = true ;
416+ }
417+ Provider :: Google => {
418+ capabilities. region = true ;
419+ capabilities. diarization = true ;
420+ capabilities. model = true ;
421+ }
422+ Provider :: Aristech => {
423+ capabilities. model = true ;
424+ }
425+ Provider :: VoiceLive => {
426+ capabilities. model = true ;
427+ }
428+ }
429+
430+ capabilities
399431 }
400432}
433+
434+ fn validate_capability (
435+ option_name : & str ,
436+ is_used : bool ,
437+ capability : bool ,
438+ provider : Provider ,
439+ ) -> Result < ( ) > {
440+ if !is_used || capability {
441+ return Ok ( ( ) ) ;
442+ }
443+
444+ bail ! (
445+ "{option_name} is unsupported for provider '{}'" ,
446+ provider
447+ . to_possible_value( )
448+ . expect( "Provider has a possible value" )
449+ . get_name( )
450+ )
451+ }
452+
453+ fn validate_provider_args (
454+ provider : Provider ,
455+ model : Option < & str > ,
456+ region : Option < & str > ,
457+ diarization : bool ,
458+ ) -> Result < ( ) > {
459+ let capabilities = provider. capabilities ( ) ;
460+
461+ validate_capability ( "--model" , model. is_some ( ) , capabilities. model , provider) ?;
462+ validate_capability ( "--region" , region. is_some ( ) , capabilities. region , provider) ?;
463+ validate_capability (
464+ "--diarization" ,
465+ diarization,
466+ capabilities. diarization ,
467+ provider,
468+ )
469+ }
0 commit comments