diff --git a/mlx_audio/audio_io.py b/mlx_audio/audio_io.py index b4ac45f7a..81d62d3bc 100644 --- a/mlx_audio/audio_io.py +++ b/mlx_audio/audio_io.py @@ -2,8 +2,8 @@ This module provides functions for reading and writing audio files. - Reading: Uses miniaudio to support WAV, MP3, FLAC, and Vorbis formats. - Uses ffmpeg for M4A/AAC format support. -- Writing: Uses miniaudio for WAV and ffmpeg for MP3, FLAC, OGG, Opus, and Vorbis encoding. + Uses ffmpeg for M4A/AAC, OGG, Opus, and WebM format support. +- Writing: Uses miniaudio for WAV and ffmpeg for MP3, FLAC, OGG, Opus, Vorbis, and WebM encoding. """ import io @@ -23,6 +23,7 @@ "vorbis": "vorbis", "m4a": "m4a", "aac": "m4a", + "webm": "webm", } # Sample format mapping @@ -46,6 +47,9 @@ def _detect_format_from_bytes(data: bytes) -> str: elif data[4:8] == b"ftyp": # M4A/MP4/AAC container format return "m4a" + elif data[:4] == b"\x1a\x45\xdf\xa3": + # WebM/Matroska container (EBML header) + return "webm" else: raise ValueError("Unable to detect audio format from bytes") @@ -72,7 +76,7 @@ def _decode_ffmpeg( " ffmpeg not found!\n" "========================================\n" "\n" - "ffmpeg is required for M4A/AAC audio decoding.\n" + "ffmpeg is required for M4A/AAC/WebM audio decoding.\n" "\n" "Install ffmpeg:\n" " macOS: brew install ffmpeg\n" @@ -182,7 +186,7 @@ def read( always_2d: bool = False, dtype: str = "float64", ) -> Tuple[np.ndarray, int]: - """Read an audio file using miniaudio (or ffmpeg for M4A/AAC). + """Read an audio file using miniaudio (or ffmpeg for M4A/AAC/OGG/Opus/WebM). Args: file: Path to the audio file or a BytesIO object. @@ -197,13 +201,17 @@ def read( use_ffmpeg = False if isinstance(file, (str, Path)): ext = Path(file).suffix.lstrip(".").lower() - if ext in ("m4a", "aac", "ogg", "opus"): + if ext in ("m4a", "aac", "ogg", "opus", "webm"): use_ffmpeg = True elif isinstance(file, io.BytesIO): file.seek(0) header = file.read(12) file.seek(0) - if header[4:8] == b"ftyp" or header[:4] == b"OggS": + if ( + header[4:8] == b"ftyp" + or header[:4] == b"OggS" + or header[:4] == b"\x1a\x45\xdf\xa3" + ): use_ffmpeg = True if use_ffmpeg: @@ -297,7 +305,7 @@ def _get_ffmpeg_path() -> str: " ffmpeg not found!\n" "========================================\n" "\n" - "ffmpeg is required for MP3/FLAC encoding and M4A/AAC decoding.\n" + "ffmpeg is required for MP3/FLAC/WebM encoding and M4A/AAC/WebM decoding.\n" "\n" "Install ffmpeg:\n" " macOS: brew install ffmpeg\n" @@ -353,6 +361,8 @@ def _encode_ffmpeg( cmd.extend(["-b:a", bitrate]) elif format == "opus": cmd.extend(["-c:a", "libopus", "-b:a", bitrate]) + elif format == "webm": + cmd.extend(["-c:a", "libopus", "-b:a", bitrate]) elif format in ("ogg", "vorbis"): # Use FLAC codec in OGG container for maximum compatibility # Native vorbis encoder has limitations (experimental, stereo-only) @@ -400,12 +410,12 @@ def write( data: Audio data as numpy array. Shape can be (samples,) for mono or (samples, channels) for multi-channel. samplerate: Sample rate in Hz. - format: Output format. Supports 'wav', 'flac', 'mp3', 'ogg', 'opus', 'vorbis'. + format: Output format. Supports 'wav', 'flac', 'mp3', 'ogg', 'opus', 'vorbis', 'webm'. If None, inferred from file extension. Note: WAV uses miniaudio for encoding. - MP3, FLAC, OGG, Opus, and Vorbis use ffmpeg (must be installed: brew install ffmpeg). + MP3, FLAC, OGG, Opus, Vorbis, and WebM use ffmpeg (must be installed: brew install ffmpeg). """ import miniaudio @@ -488,7 +498,7 @@ def write( else: miniaudio.wav_write_file(str(file), sound) - elif format in ("flac", "mp3", "ogg", "opus", "vorbis"): + elif format in ("flac", "mp3", "ogg", "opus", "vorbis", "webm"): # Check for ffmpeg early to provide a clear error message if not _check_ffmpeg_available(): import warnings diff --git a/mlx_audio/tests/test_audio_io.py b/mlx_audio/tests/test_audio_io.py index b9638bbf7..2e5851ae9 100644 --- a/mlx_audio/tests/test_audio_io.py +++ b/mlx_audio/tests/test_audio_io.py @@ -141,6 +141,36 @@ def test_write_read_vorbis(self, sample_audio_mono, tmp_path): tolerance = max(data.shape[0] * 0.2, samplerate * 0.5) assert abs(read_data.shape[0] - data.shape[0]) < tolerance + @pytest.mark.skipif(not FFMPEG_AVAILABLE, reason="ffmpeg not installed") + def test_write_read_webm(self, sample_audio_mono, tmp_path): + """Test writing and reading WebM file.""" + data, samplerate = sample_audio_mono + output_file = tmp_path / "test.webm" + + write(output_file, data, samplerate, format="webm") + assert output_file.exists() + assert output_file.stat().st_size > 0 + + # Verify we can read it back via ffmpeg + # Note: WebM with Opus internally uses 48kHz, so reading may return different sample rate + read_data, read_samplerate = read(output_file) + assert read_data.shape[0] > 0 # Just verify we got data + + @pytest.mark.skipif(not FFMPEG_AVAILABLE, reason="ffmpeg not installed") + def test_write_read_webm_stereo(self, sample_audio_stereo, tmp_path): + """Test writing and reading stereo WebM file.""" + data, samplerate = sample_audio_stereo + output_file = tmp_path / "test_stereo.webm" + + write(output_file, data, samplerate, format="webm") + assert output_file.exists() + assert output_file.stat().st_size > 0 + + read_data, read_samplerate = read(output_file) + assert read_data.shape[0] > 0 + # WebM/Opus may change channel count, just verify data is returned + assert read_data.ndim >= 1 + @pytest.mark.skipif(not FFMPEG_AVAILABLE, reason="ffmpeg not installed") def test_write_bytesio_ogg(self, sample_audio_mono): """Test writing OGG to BytesIO.""" @@ -169,6 +199,20 @@ def test_write_bytesio_opus(self, sample_audio_stereo): read_data, read_samplerate = read(buffer) assert read_data.shape[0] > 0 # Just verify we got data + @pytest.mark.skipif(not FFMPEG_AVAILABLE, reason="ffmpeg not installed") + def test_write_bytesio_webm(self, sample_audio_mono): + """Test writing WebM to BytesIO and reading it back (simulates browser blob).""" + data, samplerate = sample_audio_mono + buffer = io.BytesIO() + + write(buffer, data, samplerate, format="webm") + assert buffer.getvalue() # Should have content + + # Verify we can read it back (this is the browser blob path) + buffer.seek(0) + read_data, read_samplerate = read(buffer) + assert read_data.shape[0] > 0 # Just verify we got data + @pytest.mark.skipif(not FFMPEG_AVAILABLE, reason="ffmpeg not installed") def test_format_inference_from_extension(self, sample_audio_mono, tmp_path): """Test format inference from file extension."""