Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ for result in model.generate("Hello from MLX-Audio!", voice="af_heart"):

| Model | Description | Languages | Repo |
|-------|-------------|-----------|------|
| **Kokoro** | Fast, high-quality multilingual TTS | EN, JA, ZH, FR, ES, IT, PT, HI | [mlx-community/Kokoro-82M-bf16](https://huggingface.co/mlx-community/Kokoro-82M-bf16) |
| **Kokoro** | Fast, high-quality multilingual TTS | EN, JA, ZH, FR, ES, IT, PT, HI | [bf16](https://huggingface.co/mlx-community/Kokoro-82M-bf16), [8bit](https://huggingface.co/mlx-community/Kokoro-82M-8bit), [6bit](https://huggingface.co/mlx-community/Kokoro-82M-6bit), [4bit](https://huggingface.co/mlx-community/Kokoro-82M-4bit) |
| **Qwen3-TTS** | Alibaba's multilingual TTS with voice design | ZH, EN, JA, KO, + more | [mlx-community/Qwen3-TTS-12Hz-1.7B-VoiceDesign-bf16](https://huggingface.co/mlx-community/Qwen3-TTS-12Hz-1.7B-VoiceDesign-bf16) |
| **CSM** | Conversational Speech Model with voice cloning | EN | [mlx-community/csm-1b](https://huggingface.co/mlx-community/csm-1b) |
| **Dia** | Dialogue-focused TTS | EN | [mlx-community/Dia-1.6B-fp16](https://huggingface.co/mlx-community/Dia-1.6B-fp16) |
Expand Down Expand Up @@ -153,6 +153,9 @@ Kokoro is a fast, multilingual TTS model with 54 voice presets.
from mlx_audio.tts.utils import load_model

model = load_model("mlx-community/Kokoro-82M-bf16")
# Or use a quantized variant for lower memory usage:
# model = load_model("mlx-community/Kokoro-82M-8bit")
# model = load_model("mlx-community/Kokoro-82M-4bit")

# Generate with different voices
for result in model.generate(
Expand Down
20 changes: 9 additions & 11 deletions mlx_audio/tts/models/kokoro/istftnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,18 +962,16 @@ def __call__(self, asr, F0_curve, N, s):
x = self.generator(x, s, F0_curve) # Working in MLX
return x

def sanitize(self, key, weights):
sanitized_weights = None
def sanitize(self, key, weights, has_packed_quantized_weights=False):
if has_packed_quantized_weights:
return weights

if "noise_convs" in key and key.endswith(".weight"):
sanitized_weights = weights.transpose(0, 2, 1)
return weights.transpose(0, 2, 1)

elif "weight_v" in key:
if "weight_v" in key:
if check_array_shape(weights):
sanitized_weights = weights
else:
sanitized_weights = weights.transpose(0, 2, 1)

else:
sanitized_weights = weights
return weights
return weights.transpose(0, 2, 1)

return sanitized_weights
return weights
Comment thread
beshkenadze marked this conversation as resolved.
53 changes: 43 additions & 10 deletions mlx_audio/tts/models/kokoro/kokoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,27 @@ def __call__(
d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = self.predictor.lstm(d)
duration = self.predictor.duration_proj(x)
max_frames_per_phoneme = 100
duration = mx.sigmoid(duration).sum(axis=-1) / speed
pred_dur = mx.clip(mx.round(duration), a_min=1, a_max=None).astype(mx.int32)[0]
indices = mx.concatenate(
[mx.repeat(mx.array(i), int(n)) for i, n in enumerate(pred_dur)]
duration = mx.nan_to_num(
duration, nan=1.0, posinf=max_frames_per_phoneme, neginf=1.0
)
pred_dur = mx.clip(
mx.round(duration), a_min=1, a_max=max_frames_per_phoneme
Comment thread
beshkenadze marked this conversation as resolved.
).astype(mx.int32)[0]
indices_list = []
for i, n in enumerate(pred_dur):
count = min(max(int(n), 0), max_frames_per_phoneme)
if count > 0:
indices_list.append(mx.repeat(mx.array(i), count))
if not indices_list:
silence = mx.zeros((1, 1))
return (
self.Output(audio=silence, pred_dur=pred_dur)
if return_output
else silence
)
indices = mx.concatenate(indices_list)
Comment thread
beshkenadze marked this conversation as resolved.
pred_aln_trg = mx.zeros((input_ids.shape[1], indices.shape[0]))
pred_aln_trg[indices, mx.arange(indices.shape[0])] = 1
pred_aln_trg = pred_aln_trg[None, :]
Expand All @@ -162,8 +178,10 @@ def __call__(

def sanitize(self, weights):
sanitized_weights = {}
has_packed_quantized_weights = any(
key.endswith(".scales") or key.endswith(".biases") for key in weights
)
for key, state_dict in weights.items():

if key.startswith("bert"):
if "position_ids" in key:
# Remove unused position_ids
Expand All @@ -176,7 +194,6 @@ def sanitize(self, weights):
sanitized_weights[key] = state_dict

if key.startswith("text_encoder"):

if key.endswith((".gamma", ".beta")):
base_key = key.rsplit(".", 1)[0]
if key.endswith(".gamma"):
Expand All @@ -186,7 +203,9 @@ def sanitize(self, weights):

sanitized_weights[new_key] = state_dict
elif "weight_v" in key:
if check_array_shape(state_dict):
if has_packed_quantized_weights:
sanitized_weights[key] = state_dict
elif check_array_shape(state_dict):
sanitized_weights[key] = state_dict
else:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
Expand All @@ -210,13 +229,23 @@ def sanitize(self, weights):

if key.startswith("predictor"):
if "F0_proj.weight" in key:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
sanitized_weights[key] = (
state_dict
if has_packed_quantized_weights
else state_dict.transpose(0, 2, 1)
)

elif "N_proj.weight" in key:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
sanitized_weights[key] = (
state_dict
if has_packed_quantized_weights
else state_dict.transpose(0, 2, 1)
)

elif "weight_v" in key:
if check_array_shape(state_dict):
if has_packed_quantized_weights:
sanitized_weights[key] = state_dict
elif check_array_shape(state_dict):
sanitized_weights[key] = state_dict
else:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
Expand All @@ -239,7 +268,11 @@ def sanitize(self, weights):
sanitized_weights[key] = state_dict

if key.startswith("decoder"):
sanitized_weights[key] = self.decoder.sanitize(key, state_dict)
sanitized_weights[key] = self.decoder.sanitize(
key,
state_dict,
has_packed_quantized_weights=has_packed_quantized_weights,
)
return sanitized_weights

@property
Expand Down