|
19 | 19 | import subprocess |
20 | 20 | import os |
21 | 21 | import tempfile |
| 22 | +import wave |
22 | 23 |
|
23 | 24 | __author__ = "absadiki" |
24 | 25 | __copyright__ = "Copyright 2023, " |
@@ -281,12 +282,32 @@ def _load_audio(media_file_path: str) -> np.array: |
281 | 282 | """ |
282 | 283 |
|
283 | 284 | def wav_to_np(file_path): |
284 | | - with open(file_path, 'rb') as f: |
285 | | - f.read(44) |
286 | | - raw_data = f.read() |
287 | | - samples = np.frombuffer(raw_data, dtype=np.int16) |
288 | | - audio_array = samples.astype(np.float32) / np.iinfo(np.int16).max |
289 | | - return audio_array |
| 285 | + with wave.open(file_path, 'rb') as wf: |
| 286 | + num_channels = wf.getnchannels() |
| 287 | + sample_width = wf.getsampwidth() |
| 288 | + sample_rate = wf.getframerate() |
| 289 | + num_frames = wf.getnframes() |
| 290 | + |
| 291 | + if num_channels not in (1, 2): |
| 292 | + raise Exception(f"WAV file must be mono or stereo") |
| 293 | + |
| 294 | + if sample_rate != pw.WHISPER_SAMPLE_RATE: |
| 295 | + raise Exception(f"WAV file must be {pw.WHISPER_SAMPLE_RATE} Hz") |
| 296 | + |
| 297 | + if sample_width != 2: |
| 298 | + raise Exception(f"WAV file must be 16-bit") |
| 299 | + |
| 300 | + raw = wf.readframes(num_frames) |
| 301 | + wf.close() |
| 302 | + audio = np.frombuffer(raw, dtype=np.int16).astype(np.float32) |
| 303 | + n = num_frames |
| 304 | + if num_channels == 1: |
| 305 | + pcmf32 = audio / 32768.0 |
| 306 | + else: |
| 307 | + audio = audio.reshape(-1, 2) |
| 308 | + # Averaging the two channels |
| 309 | + pcmf32 = (audio[:, 0] + audio[:, 1]) / 65536.0 |
| 310 | + return pcmf32 |
290 | 311 |
|
291 | 312 | if media_file_path.endswith('.wav'): |
292 | 313 | return wav_to_np(media_file_path) |
|
0 commit comments