Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions mlx_audio/tts/models/voxtral_tts/voxtral_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,29 +802,18 @@ def _encode_text(self, text: str, voice: str) -> list:
"Expected [NEXT_AUDIO_TEXT] and [REPEAT_AUDIO_TEXT] from tekken.json."
)

try:
from mistral_common.protocol.speech.request import SpeechRequest
if not hasattr(self.tokenizer, "encode_speech_request"):
raise RuntimeError(
"Voxtral TTS requires mistral-common[audio] so the Mistral "
"speech tokenizer can build the prompt correctly. Install the "
"`tts` extra or add `mistral-common[audio]` to your environment."
)

req = SpeechRequest(input=text, voice=voice)
result = self.tokenizer.encode_speech_request(req)
return result.tokens
except (ImportError, AttributeError):
pass
from mistral_common.protocol.speech.request import SpeechRequest

text_tokens = self._encode_text_tokens(text)
n_voice_frames = self._voice_num_audio_tokens.get(voice)
if n_voice_frames is None:
voice_emb = self._get_voice_embedding(voice)
n_voice_frames = voice_emb.shape[0] if voice_emb is not None else 0

return (
[self.config.bos_token_id, self.config.begin_audio_token_id]
+ [self.config.audio_token_id] * n_voice_frames
+ [self._text_to_audio_token_id]
+ text_tokens
+ [self._audio_to_text_token_id]
+ [self.config.begin_audio_token_id]
)
req = SpeechRequest(input=text, voice=voice)
result = self.tokenizer.encode_speech_request(req)
return result.tokens

def _codes_to_global_indices(self, codes: mx.array) -> mx.array:
"""Convert per-codebook codes to global embedding table indices.
Expand Down
39 changes: 39 additions & 0 deletions mlx_audio/tts/tests/test_voxtral_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest
from pathlib import Path

try:
import tomllib
except ModuleNotFoundError: # Python < 3.11
import tomli as tomllib


class _DummyConfig:
    """Bare-bones config stand-in carrying only fixed special-token ids."""

    audio_token_id = 24
    begin_audio_token_id = 25
    bos_token_id = 1


class _DummyTokenizer:
    """Tokenizer stub with no methods, so it lacks `encode_speech_request`."""


class TestVoxtralDependencyContract(unittest.TestCase):
    """Checks tying Voxtral TTS to its `mistral-common[audio]` dependency."""

    def test_tts_extra_includes_mistral_common_audio(self):
        """The `tts` extra in pyproject.toml must list mistral-common[audio]."""
        # parents[3] climbs three levels above the directory holding this file.
        pyproject_path = Path(__file__).resolve().parents[3] / "pyproject.toml"

        # TOML is UTF-8 by specification. tomllib.load() takes a binary file
        # object and decodes it itself, which avoids the locale-dependent
        # default encoding used by Path.read_text().
        with pyproject_path.open("rb") as pyproject_file:
            pyproject = tomllib.load(pyproject_file)

        tts_extra = pyproject["project"]["optional-dependencies"]["tts"]
        required = "mistral-common[audio]"

        # Match on the requirement prefix so that pinning a version later
        # (e.g. "mistral-common[audio]>=1.0") does not break this check.
        self.assertTrue(
            any(dep.replace(" ", "").startswith(required) for dep in tts_extra),
            f"{required!r} not found in the tts extra: {tts_extra!r}",
        )

    def test_encode_text_requires_speech_tokenizer_support(self):
        """`_encode_text` raises RuntimeError for a tokenizer without speech support."""
        from mlx_audio.tts.models.voxtral_tts.voxtral_tts import Model

        # Model.__new__ skips __init__, so only the attributes assigned
        # below exist on the instance.
        model = Model.__new__(Model)
        model.config = _DummyConfig()
        model.tokenizer = _DummyTokenizer()
        model._voice_embeddings = {}
        model._text_to_audio_token_id = 100
        model._audio_to_text_token_id = 101

        with self.assertRaisesRegex(RuntimeError, r"mistral-common\[audio\]"):
            model._encode_text("hello world", "casual_male")
81 changes: 62 additions & 19 deletions mlx_audio/tts/tests/test_voxtral_tts_prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace
from types import ModuleType, SimpleNamespace
from unittest.mock import patch

import mlx.core as mx
Expand All @@ -10,10 +11,48 @@


class FakeTokenizer:
    """Test double for the tokenizer that records speech requests.

    Args:
        tokens: Token ids returned by `encode_speech_request`. When omitted
            (``None``), ``[201, 202, 203]`` is used. An explicitly passed
            empty list is kept as-is rather than replaced by the default.
    """

    def __init__(self, tokens=None):
        # Compare against None instead of relying on truthiness, so that
        # `tokens=[]` is not silently swapped for the default ids.
        self.tokens = [201, 202, 203] if tokens is None else tokens
        # Every request handed to encode_speech_request, in call order.
        self.requests = []

    def encode(self, text, add_special_tokens=False):
        """Return fixed text token ids; callers must not request special tokens."""
        assert add_special_tokens is False
        return [101, 102]

    def encode_speech_request(self, request):
        """Record `request` and return an object whose `tokens` is a copy of the ids."""
        self.requests.append(request)
        return SimpleNamespace(tokens=list(self.tokens))


def patch_fake_speech_request():
    """Return a `patch.dict` over `sys.modules` installing a fake SpeechRequest.

    The patcher registers stand-in modules for `mistral_common` and each
    nested package down to `mistral_common.protocol.speech.request`, with
    every child also attached as an attribute of its parent. The innermost
    module exposes `SpeechRequest`, a plain class that stores the `input`
    and `voice` it was constructed with.
    """

    class FakeSpeechRequest:
        def __init__(self, input, voice):
            self.input = input
            self.voice = voice

    segments = "mistral_common.protocol.speech.request".split(".")
    fake_modules = {}
    parent = None
    for depth, leaf in enumerate(segments, start=1):
        qualified_name = ".".join(segments[:depth])
        module = ModuleType(qualified_name)
        if parent is not None:
            # Expose the child on its parent package under its short name.
            setattr(parent, leaf, module)
        fake_modules[qualified_name] = module
        parent = module

    # After the loop `parent` is the innermost `...speech.request` module.
    parent.SpeechRequest = FakeSpeechRequest

    return patch.dict(sys.modules, fake_modules)


class TestVoxtralTTSPrompt(unittest.TestCase):
def _make_model(self):
Expand All @@ -31,27 +70,28 @@ def _make_model(self):
model._audio_to_text_token_id = 35
return model

def test_encode_text_fallback_matches_mistral_common_layout(self):
def test_encode_text_uses_speech_request_tokens(self):
model = self._make_model()
model._voice_embeddings = {"casual_male": mx.zeros((147, 3072))}
model._voice_num_audio_tokens = {"casual_male": 147}
model.tokenizer = FakeTokenizer(tokens=[7, 8, 9])

tokens = Model._encode_text(model, "Hello world.", "casual_male")
with patch_fake_speech_request():
tokens = Model._encode_text(model, "Hello world.", "casual_male")

self.assertEqual(tokens[:3], [1, 25, 24])
self.assertEqual(tokens[1 + 1 + 147], 36)
self.assertEqual(tokens[1 + 1 + 147 + 1 : 1 + 1 + 147 + 3], [101, 102])
self.assertEqual(tokens[-2:], [35, 25])
self.assertEqual(tokens, [7, 8, 9])

def test_encode_text_falls_back_to_voice_embedding_length(self):
def test_encode_text_passes_text_and_voice_to_speech_request(self):
model = self._make_model()
model._voice_embeddings = {"casual_male": mx.zeros((3, 3072))}

tokens = Model._encode_text(model, "Hello world.", "casual_male")
with patch_fake_speech_request():
tokens = Model._encode_text(model, "Hello world.", "casual_male")

self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25])
self.assertEqual(tokens, [201, 202, 203])
self.assertEqual(len(model.tokenizer.requests), 1)
request = model.tokenizer.requests[0]
self.assertEqual(request.input, "Hello world.")
self.assertEqual(request.voice, "casual_male")

def test_encode_text_lazy_loads_requested_voice_embedding(self):
def test_encode_text_does_not_lazy_load_voice_embedding_when_supported(self):
model = self._make_model()
with tempfile.TemporaryDirectory() as tmpdir:
voice_file = Path(tmpdir) / "casual_male.safetensors"
Expand All @@ -60,13 +100,16 @@ def test_encode_text_lazy_loads_requested_voice_embedding(self):

with patch(
"mlx_audio.tts.models.voxtral_tts.voxtral_tts.mx.load",
return_value={"embedding": mx.zeros((3, 3072))},
side_effect=AssertionError(
"speech-request path should not load voices"
),
) as mock_load:
tokens = Model._encode_text(model, "Hello world.", "casual_male")
with patch_fake_speech_request():
tokens = Model._encode_text(model, "Hello world.", "casual_male")

self.assertEqual(tokens, [1, 25, 24, 24, 24, 36, 101, 102, 35, 25])
mock_load.assert_called_once_with(str(voice_file))
self.assertIn("casual_male", model._voice_embeddings)
self.assertEqual(tokens, [201, 202, 203])
mock_load.assert_not_called()
self.assertNotIn("casual_male", model._voice_embeddings)

def test_get_voice_embedding_loads_once_and_caches(self):
model = self._make_model()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ stt = [

# TTS dependencies
tts = [
"mistral-common[audio]",
"tiktoken>=0.9.0",
"pydub>=0.25.1",
"sentencepiece>=0.2.0",
Expand Down