Skip to content

Commit b94b3a9

Browse files
Support custom_configuration param in ASR clients (#94)
* Passing AST param through custom_configuration * Added exception handling for TTS talk.py * Exposing custom-configurtion to cli * Updating function name to add_custom_configuration_to_config * Updating help message --------- Co-authored-by: mohnishparmar <109233781+mohnishparmar@users.noreply.github.com>
1 parent c789e98 commit b94b3a9

8 files changed

Lines changed: 44 additions & 2 deletions

File tree

riva/client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
print_streaming,
1313
sleep_audio_length,
1414
add_endpoint_parameters_to_config,
15+
add_custom_configuration_to_config,
1516
)
1617
from riva.client.auth import Auth
1718
from riva.client.nlp import (

riva/client/argparse_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def add_asr_config_argparse_parameters(
8585
type=float,
8686
help="Threshold value for likelihood of blanks before detecting end of utterance",
8787
)
88+
parser.add_argument(
89+
"--custom-configuration",
90+
default="",
91+
type=str,
92+
help="Custom configurations to be sent to the server as key value pairs <key:value,key:value,...>",
93+
)
8894
return parser
8995

9096

riva/client/asr.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,9 @@ def add_speaker_diarization_to_config(
123123
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
124124
inner_config.diarization_config.CopyFrom(diarization_config)
125125

126+
126127
def add_endpoint_parameters_to_config(
127-
config: Union[rasr.RecognitionConfig, rasr.EndpointingConfig],
128+
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
128129
start_history: int,
129130
start_threshold: float,
130131
stop_history: int,
@@ -152,6 +153,22 @@ def add_endpoint_parameters_to_config(
152153
inner_config.endpointing_config.CopyFrom(endpointing_config)
153154

154155

156+
def add_custom_configuration_to_config(
157+
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
158+
custom_configuration: str,
159+
) -> None:
160+
custom_configuration = custom_configuration.strip().replace(" ", "")
161+
if not custom_configuration:
162+
return
163+
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
164+
for pair in custom_configuration.split(","):
165+
key_value = pair.split(":")
166+
if len(key_value) == 2:
167+
inner_config.custom_configuration[key_value[0]] = key_value[1]
168+
else:
169+
raise ValueError(f"Invalid key:value pair {key_value}")
170+
171+
155172
PRINT_STREAMING_ADDITIONAL_INFO_MODES = ['no', 'time', 'confidence']
156173

157174

scripts/asr/riva_streaming_asr_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def streaming_transcription_worker(
7373
args.stop_threshold,
7474
args.stop_threshold_eou
7575
)
76+
riva.client.add_custom_configuration_to_config(
77+
config,
78+
args.custom_configuration
79+
)
7680
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
7781
for _ in range(args.num_iterations):
7882
with riva.client.AudioChunkFileIterator(

scripts/asr/transcribe_file.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def main() -> None:
109109
args.stop_threshold,
110110
args.stop_threshold_eou
111111
)
112+
riva.client.add_custom_configuration_to_config(
113+
config,
114+
args.custom_configuration
115+
)
112116
sound_callback = None
113117
try:
114118
if args.play_audio or args.output_device is not None:

scripts/asr/transcribe_file_offline.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ def main() -> None:
4646
args.stop_history_eou,
4747
args.stop_threshold,
4848
args.stop_threshold_eou
49-
)
49+
)
50+
riva.client.add_custom_configuration_to_config(
51+
config,
52+
args.custom_configuration
53+
)
5054
with args.input_file.open('rb') as fh:
5155
data = fh.read()
5256
try:

scripts/asr/transcribe_mic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ def main() -> None:
6767
args.stop_threshold,
6868
args.stop_threshold_eou
6969
)
70+
riva.client.add_custom_configuration_to_config(
71+
config,
72+
args.custom_configuration
73+
)
7074
with riva.client.audio_io.MicrophoneStream(
7175
args.sample_rate_hz,
7276
args.file_streaming_chunk,

scripts/tts/talk.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ def main() -> None:
157157
sound_stream(resp.audio)
158158
if out_f is not None:
159159
out_f.writeframesraw(resp.audio)
160+
except Exception as e:
161+
print(e.details())
160162
finally:
161163
if out_f is not None:
162164
out_f.close()

0 commit comments

Comments
 (0)