
fix(kokoro): support quantized checkpoint layout and guard NaN durations#624

Merged
Blaizzy merged 5 commits into Blaizzy:main from beshkenadze:fix/kokoro-quantized-checkpoint-layout
Apr 14, 2026
Conversation

@beshkenadze
Contributor

Summary

  • Fix Kokoro quantized checkpoint loading (8bit/6bit/4bit) — shape mismatch in sanitize()
  • Guard the duration path against NaN-derived crashes

Root cause

Quantized Kokoro checkpoints store conv-like weights (weight_v, F0_proj, N_proj, noise_convs) in MLX layout, while bf16 checkpoints use the PyTorch layout expected by the old sanitize path. The current sanitize logic transposes both, which breaks quantized checkpoints with shape mismatches like:

Expected shape (512, 3, 512) but received shape (512, 512, 3)
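The layout mismatch can be illustrated with a small sketch (NumPy stands in for mlx arrays here; the shapes come from the error message above):

```python
import numpy as np

out_ch, in_ch, kernel = 512, 512, 3

# bf16 checkpoints store conv-like weights in PyTorch layout:
pytorch_weight = np.zeros((out_ch, in_ch, kernel))  # (512, 512, 3)
# Quantized exports are already converted to MLX layout:
mlx_weight = np.zeros((out_ch, kernel, in_ch))      # (512, 3, 512)

# Unconditional transposition is correct for the PyTorch layout...
assert pytorch_weight.transpose(0, 2, 1).shape == (512, 3, 512)

# ...but applied to an already-converted quantized weight it yields
# the mismatch above: expected (512, 3, 512), received (512, 512, 3).
assert mlx_weight.transpose(0, 2, 1).shape == (512, 512, 3)
```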

Fix

  • Detect packed quantized checkpoints from .scales / .biases tensor keys
  • Skip 3D conv transposition for already-converted quantized checkpoints
  • Guard duration with nan_to_num and cap expansion to 100 frames per phoneme
  • Return silence instead of crashing on empty concatenation
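The duration-path guards in the last two bullets can be sketched as follows. This is a hedged, minimal illustration with NumPy in place of mlx; `expand_durations`, its parameters, and the silence length are illustrative names and values, not the actual mlx_audio code:

```python
import numpy as np

MAX_FRAMES_PER_PHONEME = 100  # cap from the fix description

def expand_durations(durations, sample_rate=24000):
    # Replace NaN/inf before rounding so indexing can't blow up.
    safe = np.nan_to_num(durations, nan=0.0, posinf=0.0, neginf=0.0)
    # Cap per-phoneme expansion to avoid runaway repeats/OOM.
    frames = np.clip(np.round(safe).astype(int), 0, MAX_FRAMES_PER_PHONEME)
    segments = [np.ones(n) for n in frames if n > 0]
    if not segments:
        # Return a short stretch of silence instead of crashing
        # on an empty concatenation.
        return np.zeros(sample_rate // 100)
    return np.concatenate(segments)
```

For example, an all-NaN duration vector now yields silence rather than raising, and a huge or infinite duration is clipped to 100 frames.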

Related


Fixes Blaizzy#623
@beshkenadze beshkenadze force-pushed the fix/kokoro-quantized-checkpoint-layout branch from 628af1d to 30885aa on March 30, 2026 17:54
Copilot AI review requested due to automatic review settings April 2, 2026 12:47

Copilot AI left a comment


Pull request overview

This PR fixes Kokoro model compatibility and stability issues when running inference with quantized checkpoints, and hardens the duration expansion path against NaN/inf-driven crashes.

Changes:

  • Update Kokoro weight sanitization to detect packed quantized checkpoints (.scales/.biases) and avoid double-transposing conv-like 3D weights already in MLX layout.
  • Guard duration computation against NaN/inf and cap per-phoneme frame expansion to prevent runaway repeats/OOM.
  • Document Kokoro quantized model variants in the README.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File | Description
README.md | Adds links/examples for bf16 and quantized Kokoro variants (8bit/6bit/4bit).
mlx_audio/tts/models/kokoro/kokoro.py | Adds NaN/inf guards and capped duration expansion; updates sanitize() to handle packed quantized checkpoints.
mlx_audio/tts/models/kokoro/istftnet.py | Extends the decoder's sanitize() to skip transposes for packed quantized checkpoints.
Comments suppressed due to low confidence (1)

mlx_audio/tts/models/kokoro/kokoro.py:252

  • Packed-quantized checkpoint detection (.scales/.biases) changes the sanitize behavior for conv-like weights (skipping the transpose). Please add a unit test for Model.sanitize() that includes representative 3D weights plus .scales/.biases keys and asserts the 3D tensors are not transposed in the quantized case, while the bf16/non-quantized case still transposes as before.
        has_packed_quantized_weights = any(
            key.endswith(".scales") or key.endswith(".biases") for key in weights
        )
        for key, state_dict in weights.items():
            if key.startswith("bert"):
                if "position_ids" in key:
                    # Remove unused position_ids
                    continue
                else:
                    # print(k, v.shape)
                    sanitized_weights[key] = state_dict

            if key.startswith("bert_encoder"):
                sanitized_weights[key] = state_dict

            if key.startswith("text_encoder"):
                if key.endswith((".gamma", ".beta")):
                    base_key = key.rsplit(".", 1)[0]
                    if key.endswith(".gamma"):
                        new_key = f"{base_key}.weight"
                    else:
                        new_key = f"{base_key}.bias"

                    sanitized_weights[new_key] = state_dict
                elif "weight_v" in key:
                    if has_packed_quantized_weights:
                        sanitized_weights[key] = state_dict
                    elif check_array_shape(state_dict):
                        sanitized_weights[key] = state_dict
                    else:
                        sanitized_weights[key] = state_dict.transpose(0, 2, 1)

                # Replace weight_ih_l0_reverse and weight_hh_l0_reverse with Wx and Wh
                elif key.endswith(
                    (
                        ".weight_ih_l0_reverse",
                        ".weight_hh_l0_reverse",
                        ".bias_ih_l0_reverse",
                        ".bias_hh_l0_reverse",
                        ".weight_ih_l0",
                        ".weight_hh_l0",
                        ".bias_ih_l0",
                        ".bias_hh_l0",
                    )
                ):
                    sanitized_weights.update(sanitize_lstm_weights(key, state_dict))
                else:
                    sanitized_weights[key] = state_dict

            if key.startswith("predictor"):
                if "F0_proj.weight" in key:
                    sanitized_weights[key] = (
                        state_dict
                        if has_packed_quantized_weights
                        else state_dict.transpose(0, 2, 1)
                    )

                elif "N_proj.weight" in key:
                    sanitized_weights[key] = (
                        state_dict
                        if has_packed_quantized_weights
                        else state_dict.transpose(0, 2, 1)
                    )

                elif "weight_v" in key:
                    if has_packed_quantized_weights:
                        sanitized_weights[key] = state_dict
                    elif check_array_shape(state_dict):
                        sanitized_weights[key] = state_dict
                    else:
                        sanitized_weights[key] = state_dict.transpose(0, 2, 1)
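A hedged sketch of the unit test Copilot asks for might look like this. It uses NumPy as a stand-in for mlx arrays and a minimal local `sanitize()` that mirrors only the conv-weight branch of the excerpt above, not the real `Model.sanitize()`:

```python
import numpy as np

def sanitize(weights):
    # Mirror of the excerpt's detection: packed quantized checkpoints
    # carry .scales/.biases tensor keys.
    has_packed = any(
        k.endswith(".scales") or k.endswith(".biases") for k in weights
    )
    out = {}
    for key, value in weights.items():
        if key.endswith((".scales", ".biases")):
            out[key] = value
        elif "F0_proj.weight" in key and value.ndim == 3:
            # Quantized weights are already in MLX layout: keep as-is.
            out[key] = value if has_packed else value.transpose(0, 2, 1)
        else:
            out[key] = value
    return out

def test_quantized_weights_not_transposed():
    w = np.zeros((512, 3, 512))  # already in MLX layout
    weights = {
        "predictor.F0_proj.weight": w,
        "predictor.F0_proj.scales": np.zeros((512, 8)),
    }
    assert sanitize(weights)["predictor.F0_proj.weight"].shape == (512, 3, 512)

def test_bf16_weights_still_transposed():
    w = np.zeros((512, 512, 3))  # PyTorch layout
    weights = {"predictor.F0_proj.weight": w}
    assert sanitize(weights)["predictor.F0_proj.weight"].shape == (512, 3, 512)
```

A real test would import the Kokoro `Model`, build representative 3D weights for each branch (`weight_v`, `F0_proj`, `N_proj`), and assert the same shape invariants against the actual `Model.sanitize()`.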



Owner

@Blaizzy Blaizzy left a comment


LGTM, thanks!

One note @beshkenadze: could you please not call GitHub Copilot to review, because it pollutes the view and creates noise.

You can review locally and then push here so we can review :)

@beshkenadze
Contributor Author

@Blaizzy can you please check these settings in this repo?
(screenshot of the repository's automatic review settings)

@Blaizzy
Owner

Blaizzy commented Apr 14, 2026

I think there is a misunderstanding; I'm asking you not to use Copilot. ❌

My issue is that Copilot's suggestions are mostly bad.

@beshkenadze
Contributor Author

@Blaizzy I meant to say that I didn't request a Copilot review 🫣

@Blaizzy
Owner

Blaizzy commented Apr 14, 2026

Cool, then I will check what's causing it 👌🏽

@Blaizzy Blaizzy merged commit 0de1561 into Blaizzy:main Apr 14, 2026
12 checks passed

