Skip to content

Commit 79b8000

Browse files
authored
Add speaker diarization to streaming ASR clients (#116)
* enable speaker diarization for streaming_asr_client
* add: speaker diarization to transcribe_file
* fix: print confidence when word offsets are disabled
1 parent 340e1e3 commit 79b8000

3 files changed

Lines changed: 27 additions & 6 deletions

File tree

riva/client/asr.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,7 @@ def print_streaming(
183183
word_time_offsets: bool = False,
184184
show_intermediate: bool = False,
185185
file_mode: str = 'w',
186+
speaker_diarization: bool = False,
186187
) -> None:
187188
"""
188189
Prints streaming speech recognition results to provided files or streams.
@@ -284,12 +285,21 @@ def print_streaming(
284285
if word_time_offsets:
285286
for f in output_file:
286287
f.write("Timestamps:\n")
287-
f.write('{: <40s}{: <16s}{: <16s}\n'.format('Word', 'Start (ms)', 'End (ms)'))
288+
temp = '{: <40s}{: <16s}{: <16s}'
289+
value = ['Word', 'Start (ms)', 'End (ms)']
290+
if speaker_diarization:
291+
temp += '{: <16s}'
292+
value.append('Speaker')
293+
temp += '\n'
294+
f.write(temp.format(*value))
288295
for word_info in result.alternatives[0].words:
289296
f.write(
290297
f'{word_info.word: <40s}{word_info.start_time: <16.0f}'
291-
f'{word_info.end_time: <16.0f}\n'
298+
f'{word_info.end_time: <16.0f}'
292299
)
300+
if speaker_diarization:
301+
f.write(f'{word_info.speaker_tag: <16d}')
302+
f.write('\n')
293303
else:
294304
partial_transcript += transcript
295305
else: # additional_info == 'confidence'

scripts/asr/riva_streaming_asr_client.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def streaming_transcription_worker(
6060
profanity_filter=args.profanity_filter,
6161
enable_automatic_punctuation=args.automatic_punctuation,
6262
verbatim_transcripts=not args.no_verbatim_transcripts,
63-
enable_word_time_offsets=args.word_time_offsets,
63+
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
6464
),
6565
interim_results=True,
6666
)
@@ -78,6 +78,7 @@ def streaming_transcription_worker(
7878
args.custom_configuration
7979
)
8080
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
81+
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
8182
for _ in range(args.num_iterations):
8283
with riva.client.AudioChunkFileIterator(
8384
args.input_file,
@@ -92,7 +93,8 @@ def streaming_transcription_worker(
9293
output_file=output_file,
9394
additional_info='time',
9495
file_mode='a',
95-
word_time_offsets=args.word_time_offsets,
96+
word_time_offsets=args.word_time_offsets or args.speaker_diarization,
97+
speaker_diarization=args.speaker_diarization,
9698
)
9799
except BaseException as e:
98100
exception_queue.put((e, thread_i))

scripts/asr/transcribe_file.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: MIT
33

44
import argparse
5+
from pathlib import Path
56

67
import os
78
import riva.client
@@ -50,7 +51,7 @@ def parse_args() -> argparse.Namespace:
5051
"normal speech.",
5152
)
5253
parser.add_argument(
53-
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript."
54+
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript. If `--word-time-offsets` or `--speaker-diarization` is set, then confidence is not printed."
5455
)
5556
parser = add_connection_argparse_parameters(parser)
5657
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
@@ -88,6 +89,8 @@ def main() -> None:
8889
print(f"Invalid input file path: {args.input_file}")
8990
return
9091

92+
output_file = Path(f"output.txt").expanduser()
93+
9194
config = riva.client.StreamingRecognitionConfig(
9295
config=riva.client.RecognitionConfig(
9396
language_code=args.language_code,
@@ -96,10 +99,12 @@ def main() -> None:
9699
profanity_filter=args.profanity_filter,
97100
enable_automatic_punctuation=args.automatic_punctuation,
98101
verbatim_transcripts=not args.no_verbatim_transcripts,
102+
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
99103
),
100104
interim_results=True,
101105
)
102106
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
107+
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
103108
riva.client.add_endpoint_parameters_to_config(
104109
config,
105110
args.start_history,
@@ -131,8 +136,12 @@ def main() -> None:
131136
audio_chunks=audio_chunk_iterator,
132137
streaming_config=config,
133138
),
139+
output_file=output_file,
140+
file_mode='a',
134141
show_intermediate=args.show_intermediate,
135-
additional_info="confidence" if args.print_confidence else "no",
142+
additional_info="time" if (args.word_time_offsets or args.speaker_diarization) else ("confidence" if args.print_confidence else "no"),
143+
word_time_offsets=args.word_time_offsets or args.speaker_diarization,
144+
speaker_diarization=args.speaker_diarization,
136145
)
137146
finally:
138147
if sound_callback is not None and sound_callback.opened:

0 commit comments

Comments
 (0)