diff --git a/mlx_audio/tts/models/voxtral_tts/voxtral_tts.py b/mlx_audio/tts/models/voxtral_tts/voxtral_tts.py index b8af91946..96e7b1d74 100644 --- a/mlx_audio/tts/models/voxtral_tts/voxtral_tts.py +++ b/mlx_audio/tts/models/voxtral_tts/voxtral_tts.py @@ -802,29 +802,18 @@ def _encode_text(self, text: str, voice: str) -> list: "Expected [NEXT_AUDIO_TEXT] and [REPEAT_AUDIO_TEXT] from tekken.json." ) - try: - from mistral_common.protocol.speech.request import SpeechRequest + if not hasattr(self.tokenizer, "encode_speech_request"): + raise RuntimeError( + "Voxtral TTS requires mistral-common[audio] so the Mistral " + "speech tokenizer can build the prompt correctly. Install the " + "`tts` extra or add `mistral-common[audio]` to your environment." + ) - req = SpeechRequest(input=text, voice=voice) - result = self.tokenizer.encode_speech_request(req) - return result.tokens - except (ImportError, AttributeError): - pass + from mistral_common.protocol.speech.request import SpeechRequest - text_tokens = self._encode_text_tokens(text) - n_voice_frames = self._voice_num_audio_tokens.get(voice) - if n_voice_frames is None: - voice_emb = self._get_voice_embedding(voice) - n_voice_frames = voice_emb.shape[0] if voice_emb is not None else 0 - - return ( - [self.config.bos_token_id, self.config.begin_audio_token_id] - + [self.config.audio_token_id] * n_voice_frames - + [self._text_to_audio_token_id] - + text_tokens - + [self._audio_to_text_token_id] - + [self.config.begin_audio_token_id] - ) + req = SpeechRequest(input=text, voice=voice) + result = self.tokenizer.encode_speech_request(req) + return result.tokens def _codes_to_global_indices(self, codes: mx.array) -> mx.array: """Convert per-codebook codes to global embedding table indices. diff --git a/mlx_audio/tts/tests/test_voxtral_tts.py b/mlx_audio/tts/tests/test_voxtral_tts.py new file mode 100644 index 000000000..c5d972e81 --- /dev/null +++ b/mlx_audio/tts/tests/test_voxtral_tts.py @@ -0,0 +1,39 @@ +import unittest +from pathlib import Path + +try: + import tomllib +except ModuleNotFoundError: # Python < 3.11 + import tomli as tomllib + + +class _DummyConfig: + bos_token_id = 1 + begin_audio_token_id = 25 + audio_token_id = 24 + + +class _DummyTokenizer: + pass + + +class TestVoxtralDependencyContract(unittest.TestCase): + def test_tts_extra_includes_mistral_common_audio(self): + pyproject_path = Path(__file__).resolve().parents[3] / "pyproject.toml" + pyproject = tomllib.loads(pyproject_path.read_text()) + + tts_extra = pyproject["project"]["optional-dependencies"]["tts"] + self.assertIn("mistral-common[audio]", tts_extra) + + def test_encode_text_requires_speech_tokenizer_support(self): + from mlx_audio.tts.models.voxtral_tts.voxtral_tts import Model + + model = Model.__new__(Model) + model.config = _DummyConfig() + model.tokenizer = _DummyTokenizer() + model._voice_embeddings = {} + model._text_to_audio_token_id = 100 + model._audio_to_text_token_id = 101 + + with self.assertRaisesRegex(RuntimeError, "mistral-common\\[audio\\]"): + model._encode_text("hello world", "casual_male") diff --git a/mlx_audio/tts/tests/test_voxtral_tts_prompt.py b/mlx_audio/tts/tests/test_voxtral_tts_prompt.py index d3dd5d8e6..f0d6b7c41 100644 --- a/mlx_audio/tts/tests/test_voxtral_tts_prompt.py +++ b/mlx_audio/tts/tests/test_voxtral_tts_prompt.py @@ -1,7 +1,8 @@ +import sys import tempfile import unittest from pathlib import Path -from types import SimpleNamespace +from types import ModuleType, SimpleNamespace from unittest.mock import patch import mlx.core as mx @@ -10,10 +11,48 @@ class FakeTokenizer: + def __init__(self, tokens=None): + self.tokens = tokens or [201, 202, 203] + self.requests = [] + def encode(self, text, add_special_tokens=False): assert add_special_tokens is False return [101, 102] + def encode_speech_request(self, request): + self.requests.append(request) + return SimpleNamespace(tokens=list(self.tokens)) + + +def patch_fake_speech_request(): + request_module = ModuleType("mistral_common.protocol.speech.request") + + class FakeSpeechRequest: + def __init__(self, input, voice): + self.input = input + self.voice = voice + + request_module.SpeechRequest = FakeSpeechRequest + + speech_module = ModuleType("mistral_common.protocol.speech") + speech_module.request = request_module + + protocol_module = ModuleType("mistral_common.protocol") + protocol_module.speech = speech_module + + mistral_common_module = ModuleType("mistral_common") + mistral_common_module.protocol = protocol_module + + return patch.dict( + sys.modules, + { + "mistral_common": mistral_common_module, + "mistral_common.protocol": protocol_module, + "mistral_common.protocol.speech": speech_module, + "mistral_common.protocol.speech.request": request_module, + }, + ) + class TestVoxtralTTSPrompt(unittest.TestCase): def _make_model(self): @@ -31,27 +70,28 @@ def _make_model(self): model._audio_to_text_token_id = 35 return model - def test_encode_text_fallback_matches_mistral_common_layout(self): + def test_encode_text_uses_speech_request_tokens(self): model = self._make_model() - model._voice_embeddings = {"casual_male": mx.zeros((147, 3072))} - model._voice_num_audio_tokens = {"casual_male": 147} + model.tokenizer = FakeTokenizer(tokens=[7, 8, 9]) - tokens = Model._encode_text(model, "Hello world.", "casual_male") + with patch_fake_speech_request(): + tokens = Model._encode_text(model, "Hello world.", "casual_male") - self.assertEqual(tokens[:3], [1, 25, 24]) - self.assertEqual(tokens[1 + 1 + 147], 36) - self.assertEqual(tokens[1 + 1 + 147 + 1 : 1 + 1 + 147 + 3], [101, 102]) - self.assertEqual(tokens[-2:], [35, 25]) + self.assertEqual(tokens, [7, 8, 9]) - def test_encode_text_falls_back_to_voice_embedding_length(self): + def test_encode_text_passes_text_and_voice_to_speech_request(self): model = self._make_model() - model._voice_embeddings = {"casual_male": mx.zeros((3, 3072))} - tokens = Model._encode_text(model, "Hello world.", "casual_male") + with patch_fake_speech_request(): + tokens = Model._encode_text(model, "Hello world.", "casual_male") - self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25]) + self.assertEqual(tokens, [201, 202, 203]) + self.assertEqual(len(model.tokenizer.requests), 1) + request = model.tokenizer.requests[0] + self.assertEqual(request.input, "Hello world.") + self.assertEqual(request.voice, "casual_male") - def test_encode_text_lazy_loads_requested_voice_embedding(self): + def test_encode_text_does_not_lazy_load_voice_embedding_when_supported(self): model = self._make_model() with tempfile.TemporaryDirectory() as tmpdir: voice_file = Path(tmpdir) / "casual_male.safetensors" @@ -60,13 +100,16 @@ def test_encode_text_lazy_loads_requested_voice_embedding(self): with patch( "mlx_audio.tts.models.voxtral_tts.voxtral_tts.mx.load", - return_value={"embedding": mx.zeros((3, 3072))}, + side_effect=AssertionError( + "speech-request path should not load voices" + ), ) as mock_load: - tokens = Model._encode_text(model, "Hello world.", "casual_male") + with patch_fake_speech_request(): + tokens = Model._encode_text(model, "Hello world.", "casual_male") - self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25]) - mock_load.assert_called_once_with(str(voice_file)) - self.assertIn("casual_male", model._voice_embeddings) + self.assertEqual(tokens, [201, 202, 203]) + mock_load.assert_not_called() + self.assertNotIn("casual_male", model._voice_embeddings) def test_get_voice_embedding_loads_once_and_caches(self): model = self._make_model() diff --git a/pyproject.toml b/pyproject.toml index 243af8536..9ceed5e47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ stt = [ # TTS dependencies tts = [ + "mistral-common[audio]", "tiktoken>=0.9.0", "pydub>=0.25.1", "sentencepiece>=0.2.0",