
Commit e42e143

Add LongCat-AudioDiT 1B TTS model (#627)
* add longcat audio
* add readme
* bump version
* add tests
* format
* Implement model_quant_predicate method to skip quantization for VAE in longcat_audiodit.py
* Update README.md to reflect model name change and audio playback updates
  - Changed model loading from `meituan-longcat/LongCat-AudioDiT-1B` to `mlx-community/LongCat-AudioDiT-1B-bf16`.
  - Updated audio playback code to use `AudioPlayer` instead of `sounddevice`.
  - Enhanced the available models section with new formats and additional model options.
* Implement streaming audio generation in longcat_audiodit.py
  - Added _stream_decode method for chunked audio decoding with cosine crossfade, improving time-to-first-audio.
  - Updated generate method to support streaming with new parameters: stream, streaming_interval, chunk_seconds, and overlap_seconds.
  - Introduced _format_duration static method for consistent audio duration formatting.
* Update README.md to reflect model name change for LongCat-AudioDiT
  - Changed model loading reference from `meituan-longcat/LongCat-AudioDiT-1B` to `mlx-community/LongCat-AudioDiT-1B-bf16` for consistency with updated model repository.
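The chunked decoding with cosine crossfade mentioned in the commit message can be sketched as follows. This is a minimal standalone illustration, not the `_stream_decode` implementation itself: the function name `crossfade_concat`, the chunk shapes, and the sin²/cos² (equal-gain) windows are assumptions.

```python
import numpy as np

def crossfade_concat(chunks, overlap):
    """Concatenate audio chunks, blending each `overlap`-sample seam with a
    cosine crossfade: fade_out = cos^2, fade_in = sin^2. The two windows sum
    to 1 at every sample, so a constant signal passes through unchanged."""
    t = np.linspace(0.0, np.pi / 2, overlap, endpoint=False)
    fade_in = np.sin(t) ** 2
    fade_out = np.cos(t) ** 2
    out = np.asarray(chunks[0], dtype=np.float64)
    for chunk in chunks[1:]:
        chunk = np.asarray(chunk, dtype=np.float64)
        # Blend the tail of what we have with the head of the next chunk.
        seam = out[-overlap:] * fade_out + chunk[:overlap] * fade_in
        out = np.concatenate([out[:-overlap], seam, chunk[overlap:]])
    return out
```

Because the windows sum to one, consecutive chunks of a steady signal join without amplitude dips at the seams, which is why this is preferred over a hard cut for streaming decode.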
1 parent 91caeb9 commit e42e143

11 files changed

Lines changed: 2332 additions & 1 deletion

File tree

README.md

Lines changed: 27 additions & 0 deletions
@@ -103,6 +103,7 @@ for result in model.generate("Hello from MLX-Audio!", voice="af_heart"):
 | **Ming Omni TTS (Dense)** | Lightweight dense Ming Omni variant for voice cloning and style control | EN, ZH | [mlx-community/Ming-omni-tts-0.5B-bf16](https://huggingface.co/mlx-community/Ming-omni-tts-0.5B-bf16) |
 | **KugelAudio** | SOTA 7B AR+Diffusion TTS for European languages | EN, DE, FR, ES, IT, PT, NL, PL, RU, UK, + 14 more | [kugelaudio/kugelaudio-0-open](https://huggingface.co/kugelaudio/kugelaudio-0-open) |
 | **Voxtral TTS** | Mistral's 4B multilingual TTS (20 voices, 9 languages) | EN, FR, ES, DE, IT, PT, NL, AR, HI | [mlx-community/Voxtral-4B-TTS-2603-mlx-bf16](https://huggingface.co/mlx-community/Voxtral-4B-TTS-2603-mlx-bf16) |
+| **LongCat-AudioDiT** | SOTA diffusion TTS in waveform latent space with voice cloning | ZH, EN | [mlx-community/LongCat-AudioDiT-1B-bf16](https://huggingface.co/mlx-community/LongCat-AudioDiT-1B-bf16) |

 ### Speech-to-Text (STT)

@@ -392,6 +393,32 @@ python -m mlx_audio.convert \
 > **Note:** Requires ~17GB memory (7B params in bfloat16).
 > Pre-encoded voice presets (voice cloning) are not yet available in the upstream model — the model generates speech with a default voice.

+### LongCat-AudioDiT
+
+SOTA diffusion-based TTS operating in the waveform latent space. Uses Conditional Flow Matching with a DiT backbone and WAV-VAE codec at 24kHz. Supports zero-shot voice cloning.
+
+```python
+from mlx_audio.tts.utils import load
+
+model = load("mlx-community/LongCat-AudioDiT-1B-bf16")
+
+# Zero-shot TTS
+result = next(model.generate("Hello, this is a test of AudioDiT."))
+audio = result.audio  # mx.array, 24kHz
+
+# Voice cloning (use "apg" guidance for best similarity)
+result = next(model.generate(
+    text="Today is warm turning to rain.",
+    ref_audio="reference.wav",
+    ref_text="Transcript of the reference audio.",
+    guidance_method="apg",
+    cfg_strength=4.0,
+    steps=16,
+))
+```
+
+See the [LongCat-AudioDiT README](mlx_audio/tts/models/longcat_audiodit/README.md) for all parameters and CLI usage.
+
 ### Voxtral TTS

 Mistral's 4B multilingual text-to-speech with 20 voice presets across 9 languages.
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
# LongCat-AudioDiT

State-of-the-art diffusion-based text-to-speech that operates directly in the waveform latent space. Uses Conditional Flow Matching with a DiT (Diffusion Transformer) backbone and a WAV-VAE audio codec at 24kHz. Supports zero-shot voice cloning with SOTA speaker similarity on the Seed benchmark.

**Paper:** [LongCat-AudioDiT](https://github.com/meituan-longcat/LongCat-AudioDiT/blob/main/LongCat-AudioDiT.pdf)

## Usage

Python API:

```python
from mlx_audio.tts.utils import load

model = load("mlx-community/LongCat-AudioDiT-1B-bf16")

result = next(model.generate("Hello, this is a test of AudioDiT."))
audio = result.audio  # mlx array, 24kHz
```

Play audio directly:

```python
from mlx_audio.tts.audio_player import AudioPlayer

player = AudioPlayer(sample_rate=24000)
result = next(model.generate("The quick brown fox jumps over the lazy dog."))
player.queue_audio(result.audio)
player.wait_for_drain()
player.stop()
```

## Voice Cloning

Clone any voice using a reference audio sample and its transcript. Use `guidance_method="apg"` for best voice cloning quality:

```python
result = next(model.generate(
    text="Today is warm turning to rain, with good air quality.",
    ref_audio="reference.wav",
    ref_text="Transcript of the reference audio.",
    guidance_method="apg",
    cfg_strength=4.0,
    steps=16,
))
```

## Zero-Shot Generation (Chinese)

```python
result = next(model.generate(
    text="今天晴暖转阴雨,空气质量优至良,空气相对湿度较低。",
    steps=16,
    cfg_strength=4.0,
))
```

## Generation Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `steps` | 16 | Euler ODE solver steps. Higher = better quality, slower |
| `cfg_strength` | 4.0 | Classifier-free guidance strength |
| `guidance_method` | `"cfg"` | `"cfg"` for TTS, `"apg"` for voice cloning |
| `seed` | 1024 | Random seed for reproducibility |
| `ref_audio` | `None` | Reference audio for voice cloning (24kHz) |
| `ref_text` | `None` | Transcript of the reference audio |

## CLI

```bash
# Zero-shot TTS
python -m mlx_audio.tts.generate \
    --model mlx-community/LongCat-AudioDiT-1B-bf16 \
    --text "Hello, this is a test of AudioDiT." \
    --play

# Voice cloning
python -m mlx_audio.tts.generate \
    --model mlx-community/LongCat-AudioDiT-1B-bf16 \
    --text "Today is warm turning to rain." \
    --ref_audio reference.wav \
    --ref_text "Transcript of the reference audio." \
    --play
```

## Available Models

| Model | Parameters | Format | Languages |
|-------|-----------|--------|-----------|
| `mlx-community/LongCat-AudioDiT-1B-bf16` | 1B | bf16 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-8bit` | 1B | 8-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-6bit` | 1B | 6-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-5bit` | 1B | 5-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-4bit` | 1B | 4-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-mxfp8` | 1B | MXFP8 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-mxfp4` | 1B | MXFP4 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-1B-nvfp4` | 1B | NVFP4 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-bf16` | 3.5B | bf16 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-8bit` | 3.5B | 8-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-6bit` | 3.5B | 6-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-5bit` | 3.5B | 5-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-4bit` | 3.5B | 4-bit | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-mxfp8` | 3.5B | MXFP8 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-mxfp4` | 3.5B | MXFP4 | Chinese, English |
| `mlx-community/LongCat-AudioDiT-3.5B-nvfp4` | 3.5B | NVFP4 | Chinese, English |

## Architecture

- **DiT backbone:** dim=1536, depth=24, heads=24 with RoPE and AdaLN
- **WAV-VAE codec:** latent_dim=64, 24kHz, runs in fp16
- **UMT5 text encoder:** 768-dim, 12 layers with per-layer relative position bias
- **Conditional Flow Matching** with Euler ODE solver

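The last bullet can be made concrete with a 1-D NumPy toy. This is not the model's learned velocity field: `euler_sample` and the closed-form target `x1` are illustrative assumptions. With the optimal-transport conditional field v(x, t) = (x1 − x)/(1 − t), fixed-step Euler integration lands exactly on the target after the final step.

```python
import numpy as np

def euler_sample(x0, velocity, steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = np.asarray(x0, dtype=np.float64)
    dt = 1.0 / steps
    for k in range(steps):
        t = k * dt  # t stays strictly below 1, so (1 - t) never vanishes
        x = x + dt * velocity(x, t)
    return x

# Toy stand-in for the trained network: the OT conditional velocity field
# pointing at a known target x1 (in the real model, x is a waveform latent).
x1 = np.array([0.3, -1.2, 2.0])
velocity = lambda x, t: (x1 - x) / (1.0 - t)
```

In the real sampler, `velocity` is the DiT's prediction conditioned on text (and optionally a reference latent), and `steps` is the `steps` generation parameter.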
## License

LongCat-AudioDiT weights and code are released under the [MIT License](https://github.com/meituan-longcat/LongCat-AudioDiT/blob/main/LICENSE).
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .config import ModelConfig
from .longcat_audiodit import Model
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import math
from dataclasses import dataclass, field
from typing import List, Optional

from mlx_audio.tts.models.base import BaseModelArgs


@dataclass
class VaeConfig:
    in_channels: int = 1
    channels: int = 128
    c_mults: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
    strides: List[int] = field(default_factory=lambda: [2, 4, 4, 8, 8])
    latent_dim: int = 64
    encoder_latent_dim: int = 128
    use_snake: bool = True
    downsample_shortcut: str = "averaging"
    upsample_shortcut: str = "duplicating"
    out_shortcut: str = "averaging"
    in_shortcut: str = "duplicating"
    final_tanh: bool = False
    downsampling_ratio: int = 2048
    sample_rate: int = 24000
    scale: float = 0.71
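As a sanity check on the VAE defaults above: `downsampling_ratio` (2048) is the product of the `strides`, i.e. each latent frame covers 2048 waveform samples. The standalone snippet below mirrors those config values as plain variables (it does not import the actual `VaeConfig`):

```python
import math

strides = [2, 4, 4, 8, 8]        # VaeConfig.strides
downsampling_ratio = 2048        # VaeConfig.downsampling_ratio
sample_rate = 24000              # VaeConfig.sample_rate

# Each decoder stage upsamples by its stride, so one latent frame
# covers prod(strides) waveform samples.
assert math.prod(strides) == downsampling_ratio

# At 24 kHz, one latent frame spans about 85.3 ms of audio.
frame_ms = 1000 * downsampling_ratio / sample_rate
```

Note that `ModelConfig.latent_hop` below uses the same value, 2048, keeping the DiT's latent grid aligned with the VAE's frame size.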


@dataclass
class TextEncoderConfig:
    vocab_size: int = 256384
    d_model: int = 768
    d_kv: int = 64
    d_ff: int = 2048
    num_layers: int = 12
    num_heads: int = 12
    relative_attention_num_buckets: int = 32
    relative_attention_max_distance: int = 128
    dropout_rate: float = 0.1
    layer_norm_epsilon: float = 1e-6
    is_gated_act: bool = True
    dense_act_fn: str = "gelu_new"


@dataclass
class ModelConfig(BaseModelArgs):
    model_type: str = "audiodit"
    dit_dim: int = 1536
    dit_depth: int = 24
    dit_heads: int = 24
    dit_ff_mult: float = 4.0
    dit_text_dim: int = 768
    dit_dropout: float = 0.0
    dit_bias: bool = True
    dit_cross_attn: bool = True
    dit_adaln_type: str = "global"
    dit_adaln_use_text_cond: bool = True
    dit_long_skip: bool = True
    dit_text_conv: bool = True
    dit_qk_norm: bool = True
    dit_cross_attn_norm: bool = False
    dit_eps: float = 1e-6
    dit_use_latent_condition: bool = True
    repa_dit_layer: int = 8
    latent_dim: int = 64
    sigma: float = 0.0
    sampling_rate: int = 24000
    latent_hop: int = 2048
    max_wav_duration: float = 30.0
    text_encoder_model: str = "google/umt5-base"
    text_add_embed: bool = True
    text_norm_feat: bool = True
    vae_config: Optional[VaeConfig] = None
    text_encoder_config: Optional[TextEncoderConfig] = None

    def __post_init__(self):
        if isinstance(self.vae_config, dict):
            self.vae_config = VaeConfig(
                **{
                    k: v
                    for k, v in self.vae_config.items()
                    if k in VaeConfig.__dataclass_fields__
                }
            )
        if self.vae_config is None:
            self.vae_config = VaeConfig()
        if isinstance(self.text_encoder_config, dict):
            self.text_encoder_config = TextEncoderConfig(
                **{
                    k: v
                    for k, v in self.text_encoder_config.items()
                    if k in TextEncoderConfig.__dataclass_fields__
                }
            )
        if self.text_encoder_config is None:
            self.text_encoder_config = TextEncoderConfig()
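The `__post_init__` pattern above (rebuild nested dataclasses from raw dicts, silently dropping unknown keys so older checkpoints with extra config entries still load) can be illustrated with a self-contained toy. `ToyConfig` and the `extra_key_from_hub` entry are hypothetical, standing in for `VaeConfig` and a stray key in a downloaded `config.json`:

```python
from dataclasses import dataclass

@dataclass
class ToyConfig:
    latent_dim: int = 64
    sample_rate: int = 24000

# A raw dict as it might arrive from a config file, with an unknown key.
raw = {"latent_dim": 128, "sample_rate": 24000, "extra_key_from_hub": True}

# Keep only keys that are declared dataclass fields, as __post_init__ does;
# passing the unknown key directly would raise a TypeError.
cfg = ToyConfig(**{k: v for k, v in raw.items() if k in ToyConfig.__dataclass_fields__})
```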
