1+ import sys
12import tempfile
23import unittest
34from pathlib import Path
4- from types import SimpleNamespace
5+ from types import ModuleType , SimpleNamespace
56from unittest .mock import patch
67
78import mlx .core as mx
1011
1112
class FakeTokenizer:
    """Tokenizer stand-in that returns canned ids and records speech requests."""

    def __init__(self, tokens=None):
        """Store the canned speech-request tokens.

        Args:
            tokens: Token ids returned by ``encode_speech_request``. Only
                ``None`` selects the default ``[201, 202, 203]``; an
                explicitly passed empty list is kept as-is.
        """
        # Compare against None rather than relying on truthiness so that a
        # caller asking for an empty token sequence actually gets one.
        self.tokens = [201, 202, 203] if tokens is None else tokens
        # Every request handed to encode_speech_request, in call order.
        self.requests = []

    def encode(self, text, add_special_tokens=False):
        """Return the fixed ids ``[101, 102]`` regardless of ``text``.

        Asserts that the caller did not request special tokens.
        """
        assert add_special_tokens is False
        return [101, 102]

    def encode_speech_request(self, request):
        """Record ``request`` and return an object exposing ``.tokens``.

        The returned list is a copy, so callers mutating it cannot alter the
        tokens configured on this fake.
        """
        self.requests.append(request)
        return SimpleNamespace(tokens=list(self.tokens))
25+
26+
def patch_fake_speech_request():
    """Return a ``patch.dict`` that installs a fake ``mistral_common`` tree.

    While the returned patcher is active, ``sys.modules`` contains stub
    modules for ``mistral_common`` and each nested package down to
    ``mistral_common.protocol.speech.request``. The innermost module exposes
    a ``SpeechRequest`` class that simply stores ``input`` and ``voice``.
    """

    class FakeSpeechRequest:
        # Parameter names are part of the keyword interface, so ``input``
        # stays even though it shadows the builtin.
        def __init__(self, input, voice):
            self.input = input
            self.voice = voice

    segments = "mistral_common.protocol.speech.request".split(".")
    stubs = {}
    parent = None
    for depth, segment in enumerate(segments, start=1):
        qualified_name = ".".join(segments[:depth])
        module = ModuleType(qualified_name)
        # Link each stub onto its parent so attribute access works too.
        if parent is not None:
            setattr(parent, segment, module)
        stubs[qualified_name] = module
        parent = module

    # After the loop ``parent`` is the innermost ``...speech.request`` module.
    parent.SpeechRequest = FakeSpeechRequest

    return patch.dict(sys.modules, stubs)
55+
1756
1857class TestVoxtralTTSPrompt (unittest .TestCase ):
1958 def _make_model (self ):
@@ -31,27 +70,28 @@ def _make_model(self):
3170 model ._audio_to_text_token_id = 35
3271 return model
3372
34- def test_encode_text_fallback_matches_mistral_common_layout (self ):
73+ def test_encode_text_uses_speech_request_tokens (self ):
3574 model = self ._make_model ()
36- model ._voice_embeddings = {"casual_male" : mx .zeros ((147 , 3072 ))}
37- model ._voice_num_audio_tokens = {"casual_male" : 147 }
75+ model .tokenizer = FakeTokenizer (tokens = [7 , 8 , 9 ])
3876
39- tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
77+ with patch_fake_speech_request ():
78+ tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
4079
41- self .assertEqual (tokens [:3 ], [1 , 25 , 24 ])
42- self .assertEqual (tokens [1 + 1 + 147 ], 36 )
43- self .assertEqual (tokens [1 + 1 + 147 + 1 : 1 + 1 + 147 + 3 ], [101 , 102 ])
44- self .assertEqual (tokens [- 2 :], [35 , 25 ])
80+ self .assertEqual (tokens , [7 , 8 , 9 ])
4581
46- def test_encode_text_falls_back_to_voice_embedding_length (self ):
82+ def test_encode_text_passes_text_and_voice_to_speech_request (self ):
4783 model = self ._make_model ()
48- model ._voice_embeddings = {"casual_male" : mx .zeros ((3 , 3072 ))}
4984
50- tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
85+ with patch_fake_speech_request ():
86+ tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
5187
52- self .assertEqual (tokens , [1 , 25 , 24 , 24 , 24 , 36 , 101 , 102 , 35 , 25 ])
88+ self .assertEqual (tokens , [201 , 202 , 203 ])
89+ self .assertEqual (len (model .tokenizer .requests ), 1 )
90+ request = model .tokenizer .requests [0 ]
91+ self .assertEqual (request .input , "Hello world." )
92+ self .assertEqual (request .voice , "casual_male" )
5393
54- def test_encode_text_lazy_loads_requested_voice_embedding (self ):
94+ def test_encode_text_does_not_lazy_load_voice_embedding_when_supported (self ):
5595 model = self ._make_model ()
5696 with tempfile .TemporaryDirectory () as tmpdir :
5797 voice_file = Path (tmpdir ) / "casual_male.safetensors"
@@ -60,13 +100,16 @@ def test_encode_text_lazy_loads_requested_voice_embedding(self):
60100
61101 with patch (
62102 "mlx_audio.tts.models.voxtral_tts.voxtral_tts.mx.load" ,
63- return_value = {"embedding" : mx .zeros ((3 , 3072 ))},
103+ side_effect = AssertionError (
104+ "speech-request path should not load voices"
105+ ),
64106 ) as mock_load :
65- tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
107+ with patch_fake_speech_request ():
108+ tokens = Model ._encode_text (model , "Hello world." , "casual_male" )
66109
67- self .assertEqual (tokens , [1 , 25 , 24 , 24 , 24 , 36 , 101 , 102 , 35 , 25 ])
68- mock_load .assert_called_once_with ( str ( voice_file ) )
69- self .assertIn ("casual_male" , model ._voice_embeddings )
110+ self .assertEqual (tokens , [201 , 202 , 203 ])
111+ mock_load .assert_not_called ( )
112+ self .assertNotIn ("casual_male" , model ._voice_embeddings )
70113
71114 def test_get_voice_embedding_loads_once_and_caches (self ):
72115 model = self ._make_model ()
0 commit comments