Skip to content

Commit c64890e

Browse files
authored
Merge pull request #633 from lyonsno/fix/voxtral-tts-tokenizer-contract-pr
Fix Voxtral TTS tokenizer dependency contract
2 parents 834d03b + 67db766 commit c64890e

4 files changed

Lines changed: 112 additions & 40 deletions

File tree

mlx_audio/tts/models/voxtral_tts/voxtral_tts.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -802,29 +802,18 @@ def _encode_text(self, text: str, voice: str) -> list:
802802
"Expected [NEXT_AUDIO_TEXT] and [REPEAT_AUDIO_TEXT] from tekken.json."
803803
)
804804

805-
try:
806-
from mistral_common.protocol.speech.request import SpeechRequest
805+
if not hasattr(self.tokenizer, "encode_speech_request"):
806+
raise RuntimeError(
807+
"Voxtral TTS requires mistral-common[audio] so the Mistral "
808+
"speech tokenizer can build the prompt correctly. Install the "
809+
"`tts` extra or add `mistral-common[audio]` to your environment."
810+
)
807811

808-
req = SpeechRequest(input=text, voice=voice)
809-
result = self.tokenizer.encode_speech_request(req)
810-
return result.tokens
811-
except (ImportError, AttributeError):
812-
pass
812+
from mistral_common.protocol.speech.request import SpeechRequest
813813

814-
text_tokens = self._encode_text_tokens(text)
815-
n_voice_frames = self._voice_num_audio_tokens.get(voice)
816-
if n_voice_frames is None:
817-
voice_emb = self._get_voice_embedding(voice)
818-
n_voice_frames = voice_emb.shape[0] if voice_emb is not None else 0
819-
820-
return (
821-
[self.config.bos_token_id, self.config.begin_audio_token_id]
822-
+ [self.config.audio_token_id] * n_voice_frames
823-
+ [self._text_to_audio_token_id]
824-
+ text_tokens
825-
+ [self._audio_to_text_token_id]
826-
+ [self.config.begin_audio_token_id]
827-
)
814+
req = SpeechRequest(input=text, voice=voice)
815+
result = self.tokenizer.encode_speech_request(req)
816+
return result.tokens
828817

829818
def _codes_to_global_indices(self, codes: mx.array) -> mx.array:
830819
"""Convert per-codebook codes to global embedding table indices.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import unittest
2+
from pathlib import Path
3+
4+
try:
5+
import tomllib
6+
except ModuleNotFoundError: # Python < 3.11
7+
import tomli as tomllib
8+
9+
10+
class _DummyConfig:
11+
bos_token_id = 1
12+
begin_audio_token_id = 25
13+
audio_token_id = 24
14+
15+
16+
class _DummyTokenizer:
17+
pass
18+
19+
20+
class TestVoxtralDependencyContract(unittest.TestCase):
21+
def test_tts_extra_includes_mistral_common_audio(self):
22+
pyproject_path = Path(__file__).resolve().parents[3] / "pyproject.toml"
23+
pyproject = tomllib.loads(pyproject_path.read_text())
24+
25+
tts_extra = pyproject["project"]["optional-dependencies"]["tts"]
26+
self.assertIn("mistral-common[audio]", tts_extra)
27+
28+
def test_encode_text_requires_speech_tokenizer_support(self):
29+
from mlx_audio.tts.models.voxtral_tts.voxtral_tts import Model
30+
31+
model = Model.__new__(Model)
32+
model.config = _DummyConfig()
33+
model.tokenizer = _DummyTokenizer()
34+
model._voice_embeddings = {}
35+
model._text_to_audio_token_id = 100
36+
model._audio_to_text_token_id = 101
37+
38+
with self.assertRaisesRegex(RuntimeError, "mistral-common\\[audio\\]"):
39+
model._encode_text("hello world", "casual_male")

mlx_audio/tts/tests/test_voxtral_tts_prompt.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import sys
12
import tempfile
23
import unittest
34
from pathlib import Path
4-
from types import SimpleNamespace
5+
from types import ModuleType, SimpleNamespace
56
from unittest.mock import patch
67

78
import mlx.core as mx
@@ -10,10 +11,48 @@
1011

1112

1213
class FakeTokenizer:
14+
def __init__(self, tokens=None):
15+
self.tokens = tokens or [201, 202, 203]
16+
self.requests = []
17+
1318
def encode(self, text, add_special_tokens=False):
1419
assert add_special_tokens is False
1520
return [101, 102]
1621

22+
def encode_speech_request(self, request):
23+
self.requests.append(request)
24+
return SimpleNamespace(tokens=list(self.tokens))
25+
26+
27+
def patch_fake_speech_request():
28+
request_module = ModuleType("mistral_common.protocol.speech.request")
29+
30+
class FakeSpeechRequest:
31+
def __init__(self, input, voice):
32+
self.input = input
33+
self.voice = voice
34+
35+
request_module.SpeechRequest = FakeSpeechRequest
36+
37+
speech_module = ModuleType("mistral_common.protocol.speech")
38+
speech_module.request = request_module
39+
40+
protocol_module = ModuleType("mistral_common.protocol")
41+
protocol_module.speech = speech_module
42+
43+
mistral_common_module = ModuleType("mistral_common")
44+
mistral_common_module.protocol = protocol_module
45+
46+
return patch.dict(
47+
sys.modules,
48+
{
49+
"mistral_common": mistral_common_module,
50+
"mistral_common.protocol": protocol_module,
51+
"mistral_common.protocol.speech": speech_module,
52+
"mistral_common.protocol.speech.request": request_module,
53+
},
54+
)
55+
1756

1857
class TestVoxtralTTSPrompt(unittest.TestCase):
1958
def _make_model(self):
@@ -31,27 +70,28 @@ def _make_model(self):
3170
model._audio_to_text_token_id = 35
3271
return model
3372

34-
def test_encode_text_fallback_matches_mistral_common_layout(self):
73+
def test_encode_text_uses_speech_request_tokens(self):
3574
model = self._make_model()
36-
model._voice_embeddings = {"casual_male": mx.zeros((147, 3072))}
37-
model._voice_num_audio_tokens = {"casual_male": 147}
75+
model.tokenizer = FakeTokenizer(tokens=[7, 8, 9])
3876

39-
tokens = Model._encode_text(model, "Hello world.", "casual_male")
77+
with patch_fake_speech_request():
78+
tokens = Model._encode_text(model, "Hello world.", "casual_male")
4079

41-
self.assertEqual(tokens[:3], [1, 25, 24])
42-
self.assertEqual(tokens[1 + 1 + 147], 36)
43-
self.assertEqual(tokens[1 + 1 + 147 + 1 : 1 + 1 + 147 + 3], [101, 102])
44-
self.assertEqual(tokens[-2:], [35, 25])
80+
self.assertEqual(tokens, [7, 8, 9])
4581

46-
def test_encode_text_falls_back_to_voice_embedding_length(self):
82+
def test_encode_text_passes_text_and_voice_to_speech_request(self):
4783
model = self._make_model()
48-
model._voice_embeddings = {"casual_male": mx.zeros((3, 3072))}
4984

50-
tokens = Model._encode_text(model, "Hello world.", "casual_male")
85+
with patch_fake_speech_request():
86+
tokens = Model._encode_text(model, "Hello world.", "casual_male")
5187

52-
self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25])
88+
self.assertEqual(tokens, [201, 202, 203])
89+
self.assertEqual(len(model.tokenizer.requests), 1)
90+
request = model.tokenizer.requests[0]
91+
self.assertEqual(request.input, "Hello world.")
92+
self.assertEqual(request.voice, "casual_male")
5393

54-
def test_encode_text_lazy_loads_requested_voice_embedding(self):
94+
def test_encode_text_does_not_lazy_load_voice_embedding_when_supported(self):
5595
model = self._make_model()
5696
with tempfile.TemporaryDirectory() as tmpdir:
5797
voice_file = Path(tmpdir) / "casual_male.safetensors"
@@ -60,13 +100,16 @@ def test_encode_text_lazy_loads_requested_voice_embedding(self):
60100

61101
with patch(
62102
"mlx_audio.tts.models.voxtral_tts.voxtral_tts.mx.load",
63-
return_value={"embedding": mx.zeros((3, 3072))},
103+
side_effect=AssertionError(
104+
"speech-request path should not load voices"
105+
),
64106
) as mock_load:
65-
tokens = Model._encode_text(model, "Hello world.", "casual_male")
107+
with patch_fake_speech_request():
108+
tokens = Model._encode_text(model, "Hello world.", "casual_male")
66109

67-
self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25])
68-
mock_load.assert_called_once_with(str(voice_file))
69-
self.assertIn("casual_male", model._voice_embeddings)
110+
self.assertEqual(tokens, [201, 202, 203])
111+
mock_load.assert_not_called()
112+
self.assertNotIn("casual_male", model._voice_embeddings)
70113

71114
def test_get_voice_embedding_loads_once_and_caches(self):
72115
model = self._make_model()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ stt = [
4141

4242
# TTS dependencies
4343
tts = [
44+
"mistral-common[audio]",
4445
"tiktoken>=0.9.0",
4546
"pydub>=0.25.1",
4647
"sentencepiece>=0.2.0",

0 commit comments

Comments
 (0)