|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -# pyright: reportMissingImports=false |
4 | | - |
5 | | -import os |
6 | 3 | import inspect |
| 4 | +import os |
7 | 5 | import unittest |
8 | 6 | from types import SimpleNamespace |
9 | 7 | from typing import Any, Dict, List, Optional, cast |
|
15 | 13 |
|
16 | 14 | from mlx_audio.codec.models.moss_audio_tokenizer import MossAudioTokenizer |
17 | 15 | from mlx_audio.tts.models.moss_tts.config import ModelConfig |
18 | | -from mlx_audio.utils import apply_quantization |
19 | 16 | from mlx_audio.tts.models.moss_tts.moss_tts import ( |
20 | 17 | Model, |
21 | 18 | _apply_experimental_quantization, |
22 | 19 | _apply_mixed_precision_rescue, |
| 20 | + _encode_reference_audio_with_codec, |
23 | 21 | _get_experimental_quant_mode, |
24 | 22 | _get_experimental_quant_patterns, |
25 | | - _get_mixed_precision_rescue_patterns, |
26 | | - _requantize_module_with_mode, |
27 | | - _encode_reference_audio_with_codec, |
28 | 23 | _get_generated_audio_history, |
| 24 | + _get_mixed_precision_rescue_patterns, |
29 | 25 | _normalize_reference_audio_for_codec, |
30 | 26 | _path_matches_rescue_pattern, |
| 27 | + _requantize_module_with_mode, |
31 | 28 | _suppress_token_ids, |
32 | 29 | find_last_equal, |
33 | 30 | ) |
34 | | -from mlx_audio.tts.models.moss_tts.qwen3 import Qwen3Attention |
35 | 31 | from mlx_audio.tts.models.moss_tts.processor import ( |
36 | 32 | AUDIO_PLACEHOLDER, |
37 | | - MossTTSProcessor, |
38 | 33 | AssistantMessage, |
| 34 | + MossTTSProcessor, |
| 35 | +) |
| 36 | +from mlx_audio.tts.models.moss_tts.processor import ( |
39 | 37 | apply_de_delay_pattern as processor_apply_de_delay_pattern, |
| 38 | +) |
| 39 | +from mlx_audio.tts.models.moss_tts.processor import ( |
40 | 40 | apply_delay_pattern as processor_apply_delay_pattern, |
| 41 | +) |
| 42 | +from mlx_audio.tts.models.moss_tts.processor import ( |
41 | 43 | build_user_message, |
42 | 44 | parse_output, |
43 | 45 | prepare_generation_input, |
44 | 46 | ) |
| 47 | +from mlx_audio.tts.models.moss_tts.qwen3 import Qwen3Attention |
| 48 | +from mlx_audio.utils import apply_quantization |
| 49 | + |
| 50 | +# pyright: reportMissingImports=false |
45 | 51 |
|
46 | 52 |
|
47 | 53 | class TestConfig(unittest.TestCase): |
@@ -931,20 +937,20 @@ def test_generate_smoke_if_local_weights_exist(self): |
931 | 937 | if not os.path.isdir(model_dir): |
932 | 938 | self.skipTest(f"model dir not found: {model_dir} (set MOSS_TTS_MODEL_DIR)") |
933 | 939 |
|
934 | | - from mlx_audio.tts.utils import load_model |
935 | | - |
936 | 940 | from pathlib import Path |
937 | 941 |
|
| 942 | + from mlx_audio.tts.utils import load_model |
| 943 | + |
938 | 944 | model = load_model(Path(model_dir)) |
939 | 945 | assert model.generate is not None |
940 | 946 | results = list(model.generate("Hello from MOSS", max_tokens=64, verbose=False)) |
941 | 947 | self.assertGreaterEqual(len(results), 1) |
942 | 948 |
|
943 | 949 | def test_generate_ref_audio_runs_with_fixture_q8(self): |
944 | | - from mlx_audio.tts.utils import load_model |
945 | | - |
946 | 950 | from pathlib import Path |
947 | 951 |
|
| 952 | + from mlx_audio.tts.utils import load_model |
| 953 | + |
948 | 954 | model_dir = os.environ.get("MOSS_TTS_MODEL_DIR", "./moss-tts-8bit") |
949 | 955 | if not os.path.isdir(model_dir): |
950 | 956 | self.skipTest(f"model dir not found: {model_dir} (set MOSS_TTS_MODEL_DIR)") |
|
0 commit comments