fix(kokoro): support quantized checkpoint layout and guard NaN durations (#624)
Conversation
Quantized Kokoro checkpoints (8bit/6bit/4bit) store conv-like weights in MLX layout, while bf16 checkpoints use PyTorch layout. The sanitize() functions were unconditionally transposing these weights, breaking quantized models with shape mismatches.

- Detect packed quantized checkpoints via `.scales`/`.biases` key suffixes
- Skip 3D conv transposition for already-converted quantized weights
- Guard the duration path with `nan_to_num` and cap max frames per phoneme (see the sketch below)
- Return silence instead of crashing on empty concatenation

Fixes Blaizzy#623
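For illustration, here is a minimal sketch of the guarded duration expansion described above, assuming `mx.nan_to_num` is available in the installed MLX version and using hypothetical names (`MAX_FRAMES_PER_PHONEME`, `expand_durations`, `safe_concat`); the actual implementation lives in `kokoro.py`:

```python
import mlx.core as mx

MAX_FRAMES_PER_PHONEME = 100  # cap from the PR description (hypothetical constant name)


def expand_durations(durations: mx.array) -> mx.array:
    # Replace NaN/inf predictions so they cannot explode the frame count.
    durations = mx.nan_to_num(durations, nan=0.0, posinf=0.0, neginf=0.0)
    # Round to whole frames and cap runaway repeats per phoneme.
    return mx.clip(mx.round(durations), 0, MAX_FRAMES_PER_PHONEME).astype(mx.int32)


def safe_concat(segments: list, sample_rate: int = 24000) -> mx.array:
    # Return a short stretch of silence instead of crashing on an empty list.
    if not segments:
        return mx.zeros((sample_rate // 10,))
    return mx.concatenate(segments)
```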
Pull request overview
This PR fixes Kokoro model compatibility and stability issues when running inference with quantized checkpoints, and hardens the duration expansion path against NaN/inf-driven crashes.
Changes:
- Update Kokoro weight sanitization to detect packed quantized checkpoints (`.scales`/`.biases`) and avoid double-transposing conv-like 3D weights already in MLX layout.
- Guard duration computation against NaN/inf and cap per-phoneme frame expansion to prevent runaway repeats/OOM.
- Document Kokoro quantized model variants in the README.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| README.md | Adds links/examples for bf16 and quantized Kokoro variants (8bit/6bit/4bit). |
| mlx_audio/tts/models/kokoro/kokoro.py | Adds NaN/inf guards + capped duration expansion; updates sanitize() to handle packed quantized checkpoints. |
| mlx_audio/tts/models/kokoro/istftnet.py | Extends decoder sanitize() to skip transposes for packed quantized checkpoints. |
Comments suppressed due to low confidence (1)
mlx_audio/tts/models/kokoro/kokoro.py:252
- Packed-quantized checkpoint detection (`.scales`/`.biases`) changes the sanitize behavior for conv-like weights (skipping the transpose). Please add a unit test for `Model.sanitize()` that includes representative 3D weights plus `.scales`/`.biases` keys and asserts the 3D tensors are not transposed in the quantized case, while the bf16/non-quantized case still transposes as before.
has_packed_quantized_weights = any(
key.endswith(".scales") or key.endswith(".biases") for key in weights
)
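# Packed quantized checkpoints ship conv-like weights already in MLX layout,
# so the transposes below are skipped when any .scales/.biases key is present.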
for key, state_dict in weights.items():
if key.startswith("bert"):
if "position_ids" in key:
# Remove unused position_ids
continue
else:
# print(k, v.shape)
sanitized_weights[key] = state_dict
if key.startswith("bert_encoder"):
sanitized_weights[key] = state_dict
if key.startswith("text_encoder"):
if key.endswith((".gamma", ".beta")):
base_key = key.rsplit(".", 1)[0]
if key.endswith(".gamma"):
new_key = f"{base_key}.weight"
else:
new_key = f"{base_key}.bias"
sanitized_weights[new_key] = state_dict
elif "weight_v" in key:
if has_packed_quantized_weights:
sanitized_weights[key] = state_dict
elif check_array_shape(state_dict):
sanitized_weights[key] = state_dict
else:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
# Replace weight_ih_l0_reverse and weight_hh_l0_reverse with Wx and Wh
elif key.endswith(
(
".weight_ih_l0_reverse",
".weight_hh_l0_reverse",
".bias_ih_l0_reverse",
".bias_hh_l0_reverse",
".weight_ih_l0",
".weight_hh_l0",
".bias_ih_l0",
".bias_hh_l0",
)
):
sanitized_weights.update(sanitize_lstm_weights(key, state_dict))
else:
sanitized_weights[key] = state_dict
if key.startswith("predictor"):
if "F0_proj.weight" in key:
sanitized_weights[key] = (
state_dict
if has_packed_quantized_weights
else state_dict.transpose(0, 2, 1)
)
elif "N_proj.weight" in key:
sanitized_weights[key] = (
state_dict
if has_packed_quantized_weights
else state_dict.transpose(0, 2, 1)
)
elif "weight_v" in key:
if has_packed_quantized_weights:
sanitized_weights[key] = state_dict
elif check_array_shape(state_dict):
sanitized_weights[key] = state_dict
else:
sanitized_weights[key] = state_dict.transpose(0, 2, 1)
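A minimal sketch of such a test, assuming `Model.sanitize` can be called with just a weights dict and returns the sanitized dict (adapt the call if it is an instance method); key names and shapes below are illustrative, not taken from a real checkpoint:

```python
import mlx.core as mx

from mlx_audio.tts.models.kokoro.kokoro import Model  # import path per this PR's file list


def test_sanitize_keeps_mlx_layout_for_packed_quantized_checkpoints():
    # A ".scales" key anywhere in the dict marks the checkpoint as packed/quantized,
    # so the 3D conv-like weight (already in MLX layout) must not be transposed.
    weights = {
        "predictor.F0_proj.weight": mx.zeros((1, 3, 512)),
        "bert.encoder.layer.0.attention.self.query.scales": mx.zeros((512, 8)),
    }
    out = Model.sanitize(weights)  # call signature assumed; adjust if an instance is required
    assert out["predictor.F0_proj.weight"].shape == (1, 3, 512)


def test_sanitize_transposes_pytorch_layout_for_bf16_checkpoints():
    # Without .scales/.biases keys the weight is in PyTorch layout (out, in, kernel)
    # and must still be transposed to MLX layout (out, kernel, in).
    out = Model.sanitize({"predictor.F0_proj.weight": mx.zeros((1, 512, 3))})
    assert out["predictor.F0_proj.weight"].shape == (1, 3, 512)
```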
Blaizzy left a comment:
LGTM, thanks!
One note @beshkenadze: could you please not call GitHub Copilot to review? It pollutes the view and creates noise.
You can review locally and then push here so we can review :)
@Blaizzy can you please check these settings in this repo?

I think there is a misunderstanding, I'm asking not to use Copilot. ❌ My issue is that Copilot suggestions are mostly bad.

@Blaizzy I meant to say that I didn't request a Copilot review 🫣

Cool, then I will check what's causing it 👌🏽

Summary
Quantized Kokoro models (8bit/6bit/4bit) fail to load — shape mismatch in `sanitize()`.

Root cause
Quantized Kokoro checkpoints store conv-like weights (`weight_v`, `F0_proj`, `N_proj`, `noise_convs`) in MLX layout, while bf16 checkpoints use the PyTorch layout expected by the old sanitize path. The current sanitize logic transposes both, which breaks quantized checkpoints with shape mismatches.
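To make the layout mismatch concrete, here is a small illustration (the shapes are made up for the example):

```python
import mlx.core as mx

# bf16 checkpoints ship conv-like weights in PyTorch Conv1d layout:
# (out_channels, in_channels, kernel_size)
torch_layout = mx.zeros((1, 512, 3))

# MLX conv layers expect (out_channels, kernel_size, in_channels),
# so the bf16 path transposes the last two axes:
mlx_layout = torch_layout.transpose(0, 2, 1)  # -> (1, 3, 512)

# Quantized checkpoints were already exported in MLX layout; transposing them
# again puts them back into the wrong layout and loading fails with a shape
# mismatch.
```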
Fix
- Detect packed quantized checkpoints via `.scales`/`.biases` tensor keys and skip the conv transpose for them
- Guard the duration path with `nan_to_num` and cap expansion to 100 frames per phoneme
- Return silence instead of crashing on empty concatenation

Related
- Blaizzy#623