1212
1313from grpc ._channel import _MultiThreadedRendezvous
1414
15+ import riva .client
1516import riva .client .proto .riva_asr_pb2 as rasr
1617import riva .client .proto .riva_asr_pb2_grpc as rasr_srv
1718from riva .client .auth import Auth
1819
1920
def get_wav_file_parameters(
    input_file: Union[str, os.PathLike],
) -> Optional[Dict[str, Union[int, float]]]:
    """Read the header of a WAV file and return its basic audio parameters.

    Args:
        input_file: path to the audio file; a leading ``~`` is expanded.

    Returns:
        A dictionary with the keys ``'nframes'``, ``'framerate'``,
        ``'duration'`` (seconds), ``'nchannels'``, ``'sampwidth'`` (bytes per
        sample) and ``'data_offset'`` (number of bytes in the file that
        precede the audio samples), or ``None`` when the file cannot be
        parsed as WAV (for example because it uses another container or
        encoding).

    Raises:
        OSError: if the file cannot be opened at all (missing file,
            permissions, ...). Only "this is not a WAV file" is reported as
            ``None``; a path that cannot be read is still an error.
    """
    input_file = Path(input_file).expanduser()
    try:
        with wave.open(str(input_file), 'rb') as wf:
            nframes = wf.getnframes()
            rate = wf.getframerate()
            nchannels = wf.getnchannels()
            sampwidth = wf.getsampwidth()
            # The object returned by getfp() tracks how far the reader has
            # advanced: after the header is parsed, ``offset + size_read``
            # is the position of the first audio sample.
            # NOTE(review): ``size_read``/``offset`` are attributes of
            # CPython's chunk reader, not a documented part of the wave API.
            riff = wf.getfp()
            data_offset = riff.size_read + riff.offset
    except (wave.Error, EOFError):
        # Not a WAV file (wrong magic, unsupported format tag, or truncated
        # header). Anything else - e.g. a missing file - propagates.
        return None
    if rate <= 0:
        # A header declaring a zero sample rate cannot yield a duration;
        # treat it like any other unusable WAV header.
        return None
    return {
        'nframes': nframes,
        'framerate': rate,
        'duration': nframes / rate,
        'nchannels': nchannels,
        'sampwidth': sampwidth,
        'data_offset': data_offset,
    }
3339
3440
@@ -47,7 +53,11 @@ def __init__(
4753 self .chunk_n_frames = chunk_n_frames
4854 self .delay_callback = delay_callback
4955 self .file_parameters = get_wav_file_parameters (self .input_file )
50- self .file_object : Optional [wave .Wave_read ] = wave .open (str (self .input_file ), 'rb' )
56+ self .file_object : Optional [typing .BinaryIO ] = open (str (self .input_file ), 'rb' )
57+ if self .delay_callback and self .file_parameters is None :
58+ warnings .warn (f"delay_callback not supported for encoding other than LINEAR_PCM" )
59+ self .delay_callback = None
60+ self .first_buffer = True
5161
5262 def close (self ) -> None :
5363 self .file_object .close ()
@@ -64,15 +74,19 @@ def __iter__(self):
6474 return self
6575
def __next__(self) -> bytes:
    """Return the next chunk of raw bytes read from the input file.

    For a file that parsed as WAV, ``chunk_n_frames`` is converted to a
    byte count using the sample width and channel count; otherwise
    ``chunk_n_frames`` bytes are read. The returned chunk is exactly what
    was read from the file, so for WAV input the first chunk(s) contain
    the file header.

    When ``delay_callback`` is set it is called once per chunk with the
    audio portion of the chunk (header bytes removed) and the duration of
    that audio in seconds.

    Raises:
        StopIteration: when the file is exhausted; the iterator is closed
            first.
    """
    params = self.file_parameters
    if params:
        n_bytes = self.chunk_n_frames * params['sampwidth'] * params['nchannels']
    else:
        n_bytes = self.chunk_n_frames
    data = self.file_object.read(n_bytes)
    if not data:
        self.close()
        raise StopIteration
    if self.delay_callback is not None:
        # __init__ clears delay_callback when the file did not parse as WAV,
        # so ``params`` is a dictionary whenever this branch runs.
        # Work out how many bytes of this chunk belong to the WAV header
        # from the file position rather than from a "first chunk" flag: the
        # header can be longer than one chunk, and slicing past the end of
        # the chunk used to produce a negative delay.
        chunk_start = self.file_object.tell() - len(data)
        header_bytes = min(max(params['data_offset'] - chunk_start, 0), len(data))
        audio = data[header_bytes:]
        # Interleaved PCM: one second of audio is
        # sampwidth * nchannels * framerate bytes.
        bytes_per_second = params['sampwidth'] * params['nchannels'] * params['framerate']
        self.delay_callback(audio, len(audio) / bytes_per_second)
    # Kept for code that inspects the flag; at least one chunk has been read.
    self.first_buffer = False
    return data
7791
7892
@@ -95,8 +109,9 @@ def add_audio_file_specs_to_config(
95109) -> None :
96110 inner_config : rasr .RecognitionConfig = config if isinstance (config , rasr .RecognitionConfig ) else config .config
97111 wav_parameters = get_wav_file_parameters (audio_file )
98- inner_config .sample_rate_hertz = wav_parameters ['framerate' ]
99- inner_config .audio_channel_count = wav_parameters ['nchannels' ]
112+ if wav_parameters is not None :
113+ inner_config .sample_rate_hertz = wav_parameters ['framerate' ]
114+ inner_config .audio_channel_count = wav_parameters ['nchannels' ]
100115
101116
102117def add_speaker_diarization_to_config (
0 commit comments