diff --git a/examples/indextts2_inference.py b/examples/indextts2_inference.py new file mode 100644 index 000000000..1653eaaa1 --- /dev/null +++ b/examples/indextts2_inference.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +"""Example: IndexTTS2 inference with quality-focused settings.""" + +import argparse +from pathlib import Path + +import numpy as np +from mlx_audio.audio_io import sf_write +from mlx_audio.tts import load + + +def main(): + parser = argparse.ArgumentParser(description="Run IndexTTS2 TTS inference") + parser.add_argument( + "--model", + default="indextts2_mlx", + help="Path to converted IndexTTS2 MLX model directory", + ) + parser.add_argument( + "--ref-audio", + default="examples/bible-audiobook/audios/bible-akjv/af_heart/00000001-Genesis-1:1.wav", + help="Reference speaker audio path", + ) + parser.add_argument( + "--text", + default="In the beginning, God created the heavens and the earth.", + help="Text to synthesize", + ) + parser.add_argument( + "--out", + default="indextts2_out.wav", + help="Output WAV path", + ) + parser.add_argument( + "--diffusion-steps", + type=int, + default=50, + help="Higher values are slower but often clearer (40-60 recommended)", + ) + parser.add_argument( + "--diffusion-cfg-rate", + type=float, + default=0.7, + help="Classifier-free guidance rate for s2mel diffusion", + ) + parser.add_argument( + "--repetition-penalty", + type=float, + default=10.0, + help="AR repetition penalty for semantic token decoding", + ) + args = parser.parse_args() + + model = load(Path(args.model), strict=True) + + result = next( + model.generate( + args.text, + ref_audio=args.ref_audio, + diffusion_steps=args.diffusion_steps, + diffusion_cfg_rate=args.diffusion_cfg_rate, + repetition_penalty=args.repetition_penalty, + ) + ) + + audio = np.array(result.audio, dtype=np.float32) + sf_write(args.out, audio, result.sample_rate) + + print(f"Saved: {args.out}") + print(f"Sample rate: {result.sample_rate}") + print(f"Audio duration: 
{result.audio_duration}") + print(f"RTF: {result.real_time_factor:.4f}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/indextts2/__init__.py b/mlx_audio/tts/indextts2/__init__.py new file mode 100644 index 000000000..74c3c95b5 --- /dev/null +++ b/mlx_audio/tts/indextts2/__init__.py @@ -0,0 +1,19 @@ +from .emotion import ( + CN_TO_EN, + EMO_BIAS, + EMOTION_KEYS, + QwenEmotion, + QwenEmotionConfig, + normalize_emo_vector, + parse_emotion_response, +) + +__all__ = [ + "EMOTION_KEYS", + "CN_TO_EN", + "EMO_BIAS", + "parse_emotion_response", + "normalize_emo_vector", + "QwenEmotionConfig", + "QwenEmotion", +] diff --git a/mlx_audio/tts/indextts2/emotion.py b/mlx_audio/tts/indextts2/emotion.py new file mode 100644 index 000000000..368ef19fe --- /dev/null +++ b/mlx_audio/tts/indextts2/emotion.py @@ -0,0 +1,235 @@ +import json +import re +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + + +EMOTION_KEYS = [ + "happy", + "angry", + "sad", + "afraid", + "disgusted", + "melancholic", + "surprised", + "calm", +] + + +CN_TO_EN = { + "高兴": "happy", + "愤怒": "angry", + "悲伤": "sad", + "恐惧": "afraid", + "反感": "disgusted", + "低落": "melancholic", + "惊讶": "surprised", + "自然": "calm", +} + + +EMO_BIAS = { + # Bias factors from the official IndexTTS2 inference helper. + # Order: [happy, angry, sad, afraid, disgusted, melancholic, surprised, calm] + "happy": 0.9375, + "angry": 0.875, + "sad": 1.0, + "afraid": 1.0, + "disgusted": 0.9375, + "melancholic": 0.9375, + "surprised": 0.6875, + "calm": 0.5625, +} + + +def _clamp(v: float, lo: float, hi: float) -> float: + return max(lo, min(hi, v)) + + +def _coerce_float(x: Any) -> Optional[float]: + try: + return float(x) + except Exception: + return None + + +def parse_emotion_response(text: str) -> Dict[str, float]: + """Parse a model response into an emotion dict. 
+ + Accepts either: + - JSON with English keys + - JSON with Chinese keys (mapped via CN_TO_EN) + - Loose `key: number` pairs in free-form text + """ + + text = text.strip() + + # Try to extract a JSON object substring first (common for chatty outputs). + m = re.search(r"\{[\s\S]*\}", text) + json_blob = m.group(0) if m else None + + candidates = [json_blob, text] if json_blob else [text] + for blob in candidates: + try: + obj = json.loads(blob) + if isinstance(obj, dict): + return _normalize_emotion_dict(obj) + except Exception: + pass + + # Fallback: regex parse key/value pairs. + # Matches: happy: 0.5, "angry":0.2, 高兴: 1.0, etc. + pairs: Dict[str, float] = {} + for key, val in re.findall( + r"([A-Za-z_]+|[\u4e00-\u9fff]+)\s*[:=]\s*([-+]?\d+(?:\.\d+)?)", + text, + ): + f = _coerce_float(val) + if f is None: + continue + pairs[key] = f + + return _normalize_emotion_dict(pairs) + + +def _normalize_emotion_dict(obj: Dict[str, Any]) -> Dict[str, float]: + # Map keys to English and drop unknown keys. + out: Dict[str, float] = {} + for k, v in obj.items(): + if not isinstance(k, str): + continue + key = k.strip() + if key in CN_TO_EN: + key = CN_TO_EN[key] + + if key not in EMOTION_KEYS: + continue + + f = _coerce_float(v) + if f is None: + continue + + out[key] = f + + return out + + +def normalize_emo_vector( + emo: Dict[str, float], + *, + min_score: float = 0.0, + max_score: float = 1.2, + apply_bias: bool = True, + max_sum: float = 0.8, +) -> Tuple[Dict[str, float], list[float]]: + """Clamp + bias + sum-normalize emotion vectors. + + Returns both a dict (by key) and a list in EMOTION_KEYS order. + """ + + vec: Dict[str, float] = {k: 0.0 for k in EMOTION_KEYS} + for k in EMOTION_KEYS: + if k in emo: + vec[k] = _clamp(float(emo[k]), min_score, max_score) + + # Default to neutral/calm if empty. 
+ if all(v <= 0.0 for v in vec.values()): + vec["calm"] = 1.0 + + if apply_bias: + for k in EMOTION_KEYS: + vec[k] *= EMO_BIAS[k] + + s = sum(vec.values()) + if s > max_sum and s > 0: + scale = max_sum / s + for k in EMOTION_KEYS: + vec[k] *= scale + + return vec, [vec[k] for k in EMOTION_KEYS] + + +@dataclass +class QwenEmotionConfig: + model: str = "Qwen/Qwen2.5-0.5B-Instruct-4bit" + max_tokens: int = 256 + temperature: float = 0.0 + apply_bias: bool = True + + +_LLM_CACHE: Dict[str, Tuple[Any, Any]] = {} + + +class QwenEmotion: + """Emotion-from-text using an MLX LLM (Qwen-family recommended). + + This is MLX-native (mlx_lm) and returns an IndexTTS2-style emotion vector. + """ + + def __init__(self, config: Optional[QwenEmotionConfig] = None): + self.config = config or QwenEmotionConfig() + + # Words that should tilt sad->melancholic (mirrors official helper hack). + self._melancholic_words = { + "低落", + "melancholy", + "melancholic", + "depression", + "depressed", + "gloomy", + } + + def _load_llm(self): + if self.config.model in _LLM_CACHE: + return _LLM_CACHE[self.config.model] + + from mlx_lm.utils import load as load_llm + + llm, tokenizer = load_llm(self.config.model) + _LLM_CACHE[self.config.model] = (llm, tokenizer) + return llm, tokenizer + + def _prompt(self, text: str) -> str: + llm, tokenizer = self._load_llm() + del llm + + system = ( + "You are a text emotion classifier. " + "Return ONLY valid JSON with exactly these keys: " + "happy, angry, sad, afraid, disgusted, melancholic, surprised, calm. " + "Values must be numbers in range [0.0, 1.2]." 
+ ) + + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": text}, + ] + return tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + def infer(self, text: str) -> Tuple[Dict[str, float], list[float]]: + llm, tokenizer = self._load_llm() + + from mlx_lm.generate import generate + + prompt = self._prompt(text) + resp = generate( + llm, + tokenizer, + prompt, + max_tokens=self.config.max_tokens, + temp=self.config.temperature, + verbose=False, + ) + + emo = parse_emotion_response(resp) + + # Sad vs melancholic swap workaround. + text_lower = text.lower() + if any(w in text_lower for w in self._melancholic_words): + emo["sad"], emo["melancholic"] = emo.get("melancholic", 0.0), emo.get( + "sad", 0.0 + ) + + return normalize_emo_vector(emo, apply_bias=self.config.apply_bias) diff --git a/mlx_audio/tts/models/indextts2/__init__.py b/mlx_audio/tts/models/indextts2/__init__.py new file mode 100644 index 000000000..41fc62ead --- /dev/null +++ b/mlx_audio/tts/models/indextts2/__init__.py @@ -0,0 +1,4 @@ +from mlx_audio.tts.models.indextts2.config import ModelConfig +from mlx_audio.tts.models.indextts2.indextts2 import Model + +__all__ = ["Model", "ModelConfig"] diff --git a/mlx_audio/tts/models/indextts2/config.py b/mlx_audio/tts/models/indextts2/config.py new file mode 100644 index 000000000..91e971e08 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/config.py @@ -0,0 +1,90 @@ +from dataclasses import dataclass +from typing import Any, Optional + +from mlx_audio.codec.models.bigvgan.bigvgan import BigVGANConfig +from mlx_audio.tts.models.base import BaseModelArgs +from mlx_audio.tts.models.indextts2.semantic_codec import RepCodecConfig + + +@dataclass +class ModelConfig(BaseModelArgs): + """MLX-native IndexTTS2 config. + + This module is a scaffold for a full MLX port of IndexTTS2. + Weight conversion + component implementations will be added incrementally. 
+ """ + + model_type: str = "indextts2" + + # Audio + sample_rate: int = 22050 + + # Optional LLM used to map `emo_text` -> `emo_vector`. + # This is separate from the core TTS weights. + qwen_emotion_model: str = "Qwen/Qwen2.5-0.5B-Instruct-4bit" + + # Vocoder (BigVGAN v2 22khz 80-band for official IndexTTS2) + vocoder: Optional[BigVGANConfig] = None + + # Style encoder (CAMPPlus) config dict (matches CAMPPlus __init__ args) + campplus: Optional[dict[str, Any]] = None + + # Semantic codec (MaskGCT / RepCodec) + semantic_codec: Optional[RepCodecConfig] = None + + # W2V-BERT semantic encoder (facebook/w2v-bert-2.0) + w2vbert: Optional[dict[str, Any]] = None + + # UnifiedVoice (semantic token generator) + unifiedvoice: Optional[dict[str, Any]] = None + + # s2mel flow-matching model + s2mel: Optional[dict[str, Any]] = None + + # Paths within the MLX model folder (relative to model_path) for submodules. + # These will be populated once converters exist. + bigvgan_weights: Optional[str] = None + campplus_weights: Optional[str] = None + maskgct_weights: Optional[str] = None + w2vbert_weights: Optional[str] = None + unifiedvoice_weights: Optional[str] = None + s2mel_diffusion_weights: Optional[str] = None + + # Set by loader + model_path: Optional[str] = None + + @classmethod + def from_dict(cls, config: dict[str, Any]) -> "ModelConfig": + vocoder_cfg = config.get("vocoder", None) + vocoder = ( + BigVGANConfig(**vocoder_cfg) + if isinstance(vocoder_cfg, dict) + else None + ) + + semantic_codec_cfg = config.get("semantic_codec", None) + semantic_codec = ( + RepCodecConfig(**semantic_codec_cfg) + if isinstance(semantic_codec_cfg, dict) + else None + ) + return cls( + model_type=config.get("model_type", "indextts2"), + sample_rate=int(config.get("sample_rate", 22050)), + qwen_emotion_model=config.get( + "qwen_emotion_model", "Qwen/Qwen2.5-0.5B-Instruct-4bit" + ), + vocoder=vocoder, + campplus=config.get("campplus"), + semantic_codec=semantic_codec, + 
w2vbert=config.get("w2vbert"), + unifiedvoice=config.get("unifiedvoice"), + s2mel=config.get("s2mel"), + bigvgan_weights=config.get("bigvgan_weights"), + campplus_weights=config.get("campplus_weights"), + maskgct_weights=config.get("maskgct_weights"), + w2vbert_weights=config.get("w2vbert_weights"), + unifiedvoice_weights=config.get("unifiedvoice_weights"), + s2mel_diffusion_weights=config.get("s2mel_diffusion_weights"), + model_path=config.get("model_path"), + ) diff --git a/mlx_audio/tts/models/indextts2/convert_bigvgan.py b/mlx_audio/tts/models/indextts2/convert_bigvgan.py new file mode 100644 index 000000000..3d765fd8b --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_bigvgan.py @@ -0,0 +1,129 @@ +import argparse +import json +from pathlib import Path +from typing import Any, Dict + + +def _load_torch_checkpoint(path: Path) -> Dict[str, Any]: + import torch + + ckpt = torch.load(str(path), map_location="cpu") + if not isinstance(ckpt, dict): + raise ValueError(f"Unexpected checkpoint type: {type(ckpt)}") + if "generator" in ckpt and isinstance(ckpt["generator"], dict): + return ckpt["generator"] + # Some checkpoints may be raw state_dict + return ckpt + + +def main(): + parser = argparse.ArgumentParser( + description="Convert NVIDIA BigVGAN v2 generator to MLX safetensors (IndexTTS2 vocoder)" + ) + parser.add_argument( + "--hf-repo", + type=str, + default="nvidia/bigvgan_v2_22khz_80band_256x", + help="HuggingFace repo id of the BigVGAN model", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + model.safetensors", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import mlx.core as mx + + from mlx_audio.codec.models.bigvgan.bigvgan import BigVGAN, BigVGANConfig + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + 
+ cfg_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename="config.json", + revision=args.revision, + ) + ) + gen_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename="bigvgan_generator.pt", + revision=args.revision, + ) + ) + + cfg = json.loads(cfg_path.read_text(encoding="utf-8")) + vocoder_cfg = BigVGANConfig( + num_mels=int(cfg["num_mels"]), + upsample_rates=list(map(int, cfg["upsample_rates"])), + upsample_kernel_sizes=list(map(int, cfg["upsample_kernel_sizes"])), + upsample_initial_channel=int(cfg["upsample_initial_channel"]), + resblock=str(cfg["resblock"]), + resblock_kernel_sizes=list(map(int, cfg["resblock_kernel_sizes"])), + resblock_dilation_sizes=cfg["resblock_dilation_sizes"], + activation=str(cfg["activation"]), + snake_logscale=bool(cfg["snake_logscale"]), + use_bias_at_final=bool(cfg.get("use_bias_at_final", True)), + use_tanh_at_final=bool(cfg.get("use_tanh_at_final", True)), + ) + + state = _load_torch_checkpoint(gen_path) + + # Convert tensors to mx arrays. + mx_weights = {} + for k, v in state.items(): + if hasattr(v, "detach"): + v = v.detach().cpu().numpy() + mx_weights[k] = mx.array(v) + + # Sanitize (transpose conv kernels for MLX layout where needed). + model = BigVGAN(vocoder_cfg) + mx_weights = model.sanitize(mx_weights) + + # Prefix for embedding under IndexTTS2 Model(bigvgan=...) 
+ mx_weights = {f"bigvgan.{k}": v for k, v in mx_weights.items()} + + # Save weights + mx.save_safetensors(str(out_dir / "bigvgan.safetensors"), mx_weights) + + # Merge into (or create) an mlx-audio config.json + cfg_out_path = out_dir / "config.json" + if cfg_out_path.exists(): + out_cfg = json.loads(cfg_out_path.read_text(encoding="utf-8")) + else: + out_cfg = {} + + out_cfg["model_type"] = "indextts2" + out_cfg.setdefault("sample_rate", int(cfg.get("sampling_rate", 22050))) + out_cfg["vocoder"] = { + "num_mels": vocoder_cfg.num_mels, + "upsample_rates": vocoder_cfg.upsample_rates, + "upsample_kernel_sizes": vocoder_cfg.upsample_kernel_sizes, + "upsample_initial_channel": vocoder_cfg.upsample_initial_channel, + "resblock": vocoder_cfg.resblock, + "resblock_kernel_sizes": vocoder_cfg.resblock_kernel_sizes, + "resblock_dilation_sizes": vocoder_cfg.resblock_dilation_sizes, + "activation": vocoder_cfg.activation, + "snake_logscale": vocoder_cfg.snake_logscale, + "use_bias_at_final": vocoder_cfg.use_bias_at_final, + "use_tanh_at_final": vocoder_cfg.use_tanh_at_final, + } + cfg_out_path.write_text(json.dumps(out_cfg, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 'bigvgan.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_campplus.py b/mlx_audio/tts/models/indextts2/convert_campplus.py new file mode 100644 index 000000000..c268679f6 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_campplus.py @@ -0,0 +1,104 @@ +import argparse +import json +from pathlib import Path +from typing import Any, Dict + + +def _load_state_dict(path: Path) -> Dict[str, Any]: + import torch + + sd = torch.load(str(path), map_location="cpu") + if not isinstance(sd, dict): + raise ValueError(f"Unexpected checkpoint type: {type(sd)}") + return sd + + +def main(): + parser = argparse.ArgumentParser( + description="Convert funasr/campplus CAMPPlus style 
encoder to MLX safetensors (IndexTTS2)" + ) + parser.add_argument( + "--hf-repo", + type=str, + default="funasr/campplus", + help="HuggingFace repo id containing campplus_cn_common.bin", + ) + parser.add_argument( + "--filename", + type=str, + default="campplus_cn_common.bin", + help="Checkpoint filename in the HF repo", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + campplus.safetensors", + ) + parser.add_argument( + "--feat-dim", + type=int, + default=80, + help="Input feature dimension (IndexTTS2 uses 80)", + ) + parser.add_argument( + "--embedding-size", + type=int, + default=192, + help="Output embedding size (IndexTTS2 uses 192)", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import mlx.core as mx + + from mlx_audio.tts.models.chatterbox.s3gen.xvector import CAMPPlus + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + ckpt_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename=args.filename, + revision=args.revision, + ) + ) + + state = _load_state_dict(ckpt_path) + + # Convert tensors to mx arrays. 
+ mx_weights = {} + for k, v in state.items(): + if hasattr(v, "detach"): + v = v.detach().cpu().numpy() + mx_weights[k] = mx.array(v) + + camp = CAMPPlus(feat_dim=args.feat_dim, embedding_size=args.embedding_size) + mx_weights = camp.sanitize(mx_weights) + mx_weights = {f"campplus.{k}": v for k, v in mx_weights.items()} + + mx.save_safetensors(str(out_dir / "campplus.safetensors"), mx_weights) + + # Update or create config.json + cfg_path = out_dir / "config.json" + if cfg_path.exists(): + cfg = json.loads(cfg_path.read_text(encoding="utf-8")) + else: + cfg = {"model_type": "indextts2", "sample_rate": 22050} + + cfg["campplus"] = {"feat_dim": args.feat_dim, "embedding_size": args.embedding_size} + cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 'campplus.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_maskgct_semantic_codec.py b/mlx_audio/tts/models/indextts2/convert_maskgct_semantic_codec.py new file mode 100644 index 000000000..4a730c7b2 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_maskgct_semantic_codec.py @@ -0,0 +1,90 @@ +import argparse +import json +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Convert MaskGCT semantic codec (RepCodec) weights to MLX safetensors" + ) + parser.add_argument( + "--hf-repo", + type=str, + default="amphion/MaskGCT", + help="HuggingFace repo id containing semantic_codec/model.safetensors", + ) + parser.add_argument( + "--filename", + type=str, + default="semantic_codec/model.safetensors", + help="Path to the safetensors file within the HF repo", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + 
semantic_codec.safetensors", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import mlx.core as mx + from safetensors import safe_open + + from mlx_audio.tts.models.indextts2.semantic_codec import RepCodec, RepCodecConfig + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + st_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename=args.filename, + revision=args.revision, + ) + ) + + weights = {} + with safe_open(str(st_path), framework="numpy") as f: + for k in f.keys(): + weights[k] = mx.array(f.get_tensor(k)) + + # Default config matches IndexTTS2 official config.yaml + cfg = RepCodecConfig() + model = RepCodec(cfg) + + weights = model.sanitize(weights) + weights = {f"semantic_codec.{k}": v for k, v in weights.items()} + + mx.save_safetensors(str(out_dir / "semantic_codec.safetensors"), weights) + + cfg_path = out_dir / "config.json" + if cfg_path.exists(): + cfg_json = json.loads(cfg_path.read_text(encoding="utf-8")) + else: + cfg_json = {"model_type": "indextts2", "sample_rate": 22050} + + cfg_json["semantic_codec"] = { + "codebook_size": cfg.codebook_size, + "hidden_size": cfg.hidden_size, + "codebook_dim": cfg.codebook_dim, + "vocos_dim": cfg.vocos_dim, + "vocos_intermediate_dim": cfg.vocos_intermediate_dim, + "vocos_num_layers": cfg.vocos_num_layers, + "num_quantizers": cfg.num_quantizers, + "downsample_scale": cfg.downsample_scale, + } + cfg_path.write_text(json.dumps(cfg_json, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 'semantic_codec.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_s2mel.py b/mlx_audio/tts/models/indextts2/convert_s2mel.py new file mode 100644 index 000000000..1cc7215b3 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_s2mel.py @@ -0,0 +1,172 @@ +import argparse +import json +from pathlib import Path +from typing 
import Dict, Tuple + + +def _linear_strip_weightnorm(weight_g, weight_v, eps: float = 1e-8): + # w = g * v / ||v|| + import numpy as np + + v = weight_v + norm = np.sqrt(np.sum(v * v, axis=1, keepdims=True) + eps) + g = weight_g.reshape(-1, 1) + return v * (g / norm) + + +def _conv1d_strip_weightnorm(weight_g, weight_v, eps: float = 1e-8): + # torch conv1d weight_v: (O, I, K) + import numpy as np + + v = weight_v + norm = np.sqrt(np.sum(v * v, axis=(1, 2), keepdims=True) + eps) + g = weight_g.reshape(-1, 1, 1) + return v * (g / norm) + + +def _transpose_conv1d_torch_to_mlx(w): + # (O, I, K) -> (O, K, I) + return w.transpose(0, 2, 1) + + +def _transpose_conv2d_torch_to_mlx(w): + # (O, I, KH, KW) -> (O, KH, KW, I) + return w.transpose(0, 2, 3, 1) + + +def main(): + parser = argparse.ArgumentParser(description="Convert IndexTTS2 s2mel.pth to MLX safetensors") + parser.add_argument( + "--hf-repo", + type=str, + default="IndexTeam/IndexTTS-2", + help="HuggingFace repo id containing s2mel.pth and config.yaml", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + s2mel.safetensors", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import torch + import mlx.core as mx + import yaml + + from mlx_audio.tts.models.indextts2.s2mel import S2MelConfig, S2MelModel + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + cfg_path = Path(hf_hub_download(args.hf_repo, "config.yaml", revision=args.revision)) + ckpt_path = Path(hf_hub_download(args.hf_repo, "s2mel.pth", revision=args.revision)) + + cfg_yaml = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) + s2mel_cfg = S2MelConfig.from_dict(cfg_yaml["s2mel"]) + model = S2MelModel(s2mel_cfg) + + obj = torch.load(str(ckpt_path), map_location="cpu") + sd = obj["net"] + + # Flatten into one dict 
matching MLX parameter names. + # MLX expects weights for: + # - s2mel.cfm.estimator.* + # - s2mel.length_regulator.* + # - s2mel.gpt_layer.* + # The torch checkpoint stores: cfm.*, length_regulator.*, gpt_layer.* + + out: Dict[str, mx.array] = {} + + def add_module(prefix_out: str, module_sd: Dict[str, torch.Tensor]): + # Normalize key names to match our MLX modules (strip SConv1d wrappers) + norm_sd: Dict[str, torch.Tensor] = {} + for k, v in module_sd.items(): + nk = k.replace(".conv.conv.", ".") + norm_sd[nk] = v + module_sd = norm_sd + + # Handle weightnorm pairs + used = set() + for k, v in module_sd.items(): + if k.endswith(".weight_g"): + base = k[: -len(".weight_g")] + v_key = base + ".weight_v" + if v_key not in module_sd: + continue + w_g = module_sd[k].detach().cpu().numpy() + w_v = module_sd[v_key].detach().cpu().numpy() + if w_v.ndim == 2: + w = _linear_strip_weightnorm(w_g, w_v) + elif w_v.ndim == 3: + w = _conv1d_strip_weightnorm(w_g, w_v) + elif w_v.ndim == 4: + # rare + w = w_v + else: + w = w_v + if w.ndim == 3: + w = _transpose_conv1d_torch_to_mlx(w) + elif w.ndim == 4: + w = _transpose_conv2d_torch_to_mlx(w) + out[prefix_out + base + ".weight"] = mx.array(w) + used.add(k) + used.add(v_key) + + for k, v in module_sd.items(): + if k in used: + continue + arr = v.detach().cpu().numpy() if hasattr(v, "detach") else v + if getattr(arr, "ndim", 0) == 3: + arr = _transpose_conv1d_torch_to_mlx(arr) + elif getattr(arr, "ndim", 0) == 4: + arr = _transpose_conv2d_torch_to_mlx(arr) + out[prefix_out + k] = mx.array(arr) + + add_module("s2mel.cfm.", sd["cfm"]) + add_module("s2mel.length_regulator.", sd["length_regulator"]) + add_module("s2mel.gpt_layer.", sd["gpt_layer"]) + + # Add derived / MLX-only parameters that are not stored in torch ckpt + # 1) freqs_cis buffer for GPTFastTransformer + head_dim = int(cfg_yaml["s2mel"]["DiT"].get("hidden_dim", 512)) // int( + cfg_yaml["s2mel"]["DiT"].get("num_heads", 8) + ) + seq_len = 16384 + base = 10000 + import 
numpy as np + freqs = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim)) + t = np.arange(seq_len, dtype=np.float32) + outer = np.outer(t, freqs) + freqs_cis = np.stack([np.cos(outer), np.sin(outer)], axis=-1) + out["s2mel.cfm.estimator.transformer.freqs_cis"] = mx.array(freqs_cis) + + # 2) FinalLayer LayerNorm affine params (MLX LayerNorm always has them) + out["s2mel.cfm.estimator.final_layer.norm_final.weight"] = mx.ones((512,), dtype=mx.float32) + out["s2mel.cfm.estimator.final_layer.norm_final.bias"] = mx.zeros((512,), dtype=mx.float32) + + mx.save_safetensors(str(out_dir / "s2mel.safetensors"), out) + + # Merge config.json + cfg_out = out_dir / "config.json" + if cfg_out.exists(): + root = json.loads(cfg_out.read_text(encoding="utf-8")) + else: + root = {"model_type": "indextts2", "sample_rate": 22050} + + root["s2mel"] = cfg_yaml["s2mel"] + cfg_out.write_text(json.dumps(root, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 's2mel.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_unifiedvoice.py b/mlx_audio/tts/models/indextts2/convert_unifiedvoice.py new file mode 100644 index 000000000..c94246ea8 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_unifiedvoice.py @@ -0,0 +1,94 @@ +import argparse +import json +from pathlib import Path + + +def _transpose_conv1d(w): + # torch (O, I, K) -> mlx (O, K, I) + return w.transpose(0, 2, 1) + + +def _transpose_conv2d(w): + # torch (O, I, KH, KW) -> mlx (O, KH, KW, I) + return w.transpose(0, 2, 3, 1) + + +def main(): + parser = argparse.ArgumentParser(description="Convert IndexTTS2 gpt.pth (UnifiedVoice) to MLX safetensors") + parser.add_argument( + "--hf-repo", + type=str, + default="IndexTeam/IndexTTS-2", + help="HuggingFace repo id containing gpt.pth and config.yaml", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional 
HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + unifiedvoice.safetensors + bpe.model", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import torch + import mlx.core as mx + import yaml + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + cfg_path = Path(hf_hub_download(args.hf_repo, "config.yaml", revision=args.revision)) + ckpt_path = Path(hf_hub_download(args.hf_repo, "gpt.pth", revision=args.revision)) + bpe_path = Path(hf_hub_download(args.hf_repo, "bpe.model", revision=args.revision)) + + cfg_yaml = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) + gpt_cfg = cfg_yaml["gpt"] + + # Copy bpe.model alongside config for local loading + (out_dir / "bpe.model").write_bytes(bpe_path.read_bytes()) + + sd = torch.load(str(ckpt_path), map_location="cpu") + + weights = {} + for k, v in sd.items(): + arr = v.detach().cpu().numpy() if hasattr(v, "detach") else v + + # Conv weights + if arr.ndim == 3 and ("conv" in k or "depthwise_conv" in k or "pointwise_conv" in k): + arr = _transpose_conv1d(arr) + if arr.ndim == 4 and "conv" in k: + arr = _transpose_conv2d(arr) + + # GPT2 (mlx-lm) expects transposed Linear weights in several places. 
+ if ".attn.c_attn.weight" in k or ".attn.c_proj.weight" in k or ".mlp.c_fc.weight" in k or ".mlp.c_proj.weight" in k: + if arr.ndim == 2: + arr = arr.transpose(1, 0) + + weights[f"unifiedvoice.{k}"] = mx.array(arr) + + mx.save_safetensors(str(out_dir / "unifiedvoice.safetensors"), weights) + + # Merge config.json + cfg_out = out_dir / "config.json" + if cfg_out.exists(): + root = json.loads(cfg_out.read_text(encoding="utf-8")) + else: + root = {"model_type": "indextts2", "sample_rate": 22050} + + root["unifiedvoice"] = gpt_cfg + root["unifiedvoice"]["bpe_model"] = "bpe.model" + cfg_out.write_text(json.dumps(root, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 'unifiedvoice.safetensors'}") + print(f"Saved BPE: {out_dir / 'bpe.model'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_w2vbert.py b/mlx_audio/tts/models/indextts2/convert_w2vbert.py new file mode 100644 index 000000000..80a9ef0a1 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_w2vbert.py @@ -0,0 +1,107 @@ +import argparse +import json +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Convert facebook/w2v-bert-2.0 to an MLX-compatible safetensors for IndexTTS2" + ) + parser.add_argument( + "--hf-repo", + type=str, + default="facebook/w2v-bert-2.0", + help="HuggingFace repo id", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HF revision/commit", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + w2vbert.safetensors", + ) + args = parser.parse_args() + + from huggingface_hub import hf_hub_download + import mlx.core as mx + from safetensors import safe_open + + from mlx_audio.tts.models.indextts2.w2vbert import Wav2Vec2BertConfig, Wav2Vec2BertModel + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, 
exist_ok=True) + + cfg_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename="config.json", + revision=args.revision, + ) + ) + model_path = Path( + hf_hub_download( + repo_id=args.hf_repo, + filename="model.safetensors", + revision=args.revision, + ) + ) + + cfg_json = json.loads(cfg_path.read_text(encoding="utf-8")) + cfg = Wav2Vec2BertConfig( + hidden_size=int(cfg_json["hidden_size"]), + num_hidden_layers=int(cfg_json["num_hidden_layers"]), + num_attention_heads=int(cfg_json["num_attention_heads"]), + intermediate_size=int(cfg_json["intermediate_size"]), + feature_projection_input_dim=int(cfg_json["feature_projection_input_dim"]), + layer_norm_eps=float(cfg_json.get("layer_norm_eps", 1e-5)), + position_embeddings_type=cfg_json.get("position_embeddings_type", "relative_key"), + rotary_embedding_base=int(cfg_json.get("rotary_embedding_base", 10000)), + max_source_positions=int(cfg_json.get("max_source_positions", 5000)), + left_max_position_embeddings=int(cfg_json.get("left_max_position_embeddings", 64)), + right_max_position_embeddings=int(cfg_json.get("right_max_position_embeddings", 8)), + conv_depthwise_kernel_size=int(cfg_json.get("conv_depthwise_kernel_size", 31)), + conformer_conv_dropout=float(cfg_json.get("conformer_conv_dropout", 0.1)), + hidden_dropout=float(cfg_json.get("hidden_dropout", 0.0)), + activation_dropout=float(cfg_json.get("activation_dropout", 0.0)), + attention_dropout=float(cfg_json.get("attention_dropout", 0.0)), + feat_proj_dropout=float(cfg_json.get("feat_proj_dropout", 0.0)), + ) + + model = Wav2Vec2BertModel(cfg) + + weights = {} + # Use numpy framework to avoid requiring torch at conversion time. 
+ with safe_open(str(model_path), framework="numpy") as f: + for k in f.keys(): + weights[k] = mx.array(f.get_tensor(k)) + + weights = model.sanitize(weights) + weights = {f"w2vbert.{k}": v for k, v in weights.items()} + + mx.save_safetensors(str(out_dir / "w2vbert.safetensors"), weights) + + cfg_out_path = out_dir / "config.json" + if cfg_out_path.exists(): + root_cfg = json.loads(cfg_out_path.read_text(encoding="utf-8")) + else: + root_cfg = {"model_type": "indextts2", "sample_rate": 22050} + + root_cfg["w2vbert"] = { + "hf_repo": args.hf_repo, + "config": cfg_json, + "weights": "w2vbert.safetensors", + "prefix": "w2vbert.", + } + cfg_out_path.write_text(json.dumps(root_cfg, indent=2), encoding="utf-8") + + print(f"Saved MLX weights: {out_dir / 'w2vbert.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/convert_w2vbert_stats.py b/mlx_audio/tts/models/indextts2/convert_w2vbert_stats.py new file mode 100644 index 000000000..9be090652 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/convert_w2vbert_stats.py @@ -0,0 +1,61 @@ +import argparse +import json +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Convert wav2vec2-bert stats (mean/var) to MLX safetensors (IndexTTS2)" + ) + parser.add_argument( + "--stats-pt", + type=str, + required=True, + help="Path to wav2vec2bert_stats.pt (torch file with mean/var)", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory containing config.json + w2vbert_stats.safetensors", + ) + args = parser.parse_args() + + import torch + import mlx.core as mx + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + stats = torch.load(args.stats_pt, map_location="cpu") + if not isinstance(stats, dict) or "mean" not in stats or "var" not in stats: + raise ValueError("Expected a dict with keys 'mean' and 'var'") + + mean = 
stats["mean"].detach().cpu().numpy() + var = stats["var"].detach().cpu().numpy() + + weights = { + "w2vbert_stats.mean": mx.array(mean), + "w2vbert_stats.std": mx.sqrt(mx.array(var)), + } + + mx.save_safetensors(str(out_dir / "w2vbert_stats.safetensors"), weights) + + cfg_path = out_dir / "config.json" + if cfg_path.exists(): + cfg = json.loads(cfg_path.read_text(encoding="utf-8")) + else: + cfg = {"model_type": "indextts2", "sample_rate": 22050} + + cfg["w2vbert_stats"] = { + "mean": "w2vbert_stats.safetensors::w2vbert_stats.mean", + "std": "w2vbert_stats.safetensors::w2vbert_stats.std", + } + cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8") + + print(f"Saved stats: {out_dir / 'w2vbert_stats.safetensors'}") + print(f"Saved config: {out_dir / 'config.json'}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/tts/models/indextts2/indextts2.py b/mlx_audio/tts/models/indextts2/indextts2.py new file mode 100644 index 000000000..db9eb9bdf --- /dev/null +++ b/mlx_audio/tts/models/indextts2/indextts2.py @@ -0,0 +1,362 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from mlx_audio.codec.models.bigvgan.bigvgan import BigVGAN +from mlx_audio.tts.indextts2.emotion import QwenEmotion, QwenEmotionConfig +from mlx_audio.tts.models.base import GenerationResult, adjust_speed +from mlx_audio.utils import load_audio +from mlx_audio.dsp import compute_fbank_kaldi, mel_filters, stft + +from .config import ModelConfig +from .semantic_codec import RepCodec +from .w2vbert_features import W2VBertFeatureExtractor, W2VBertFeatureExtractorConfig +from .w2vbert_stats import W2VBertStats +from .w2vbert import Wav2Vec2BertConfig, Wav2Vec2BertModel +from .unifiedvoice import UnifiedVoice, UnifiedVoiceConfig +from .s2mel import S2MelConfig, S2MelModel + + +class Model(nn.Module): + """MLX-native IndexTTS2 
(scaffold). + + The full pipeline (w2v-bert -> MaskGCT -> UnifiedVoice -> s2mel diffusion -> BigVGAN) + is implemented step-by-step. For now this class wires up the public generate() + interface and emotion-from-text support. + """ + + def __init__(self, config: Union[ModelConfig, dict]): + super().__init__() + if isinstance(config, dict): + config = ModelConfig.from_dict(config) + self.config = config + + self.model_type = config.model_type + self.sample_rate = config.sample_rate + + self._emotion: Optional[QwenEmotion] = None + + self.bigvgan: Optional[BigVGAN] = ( + BigVGAN(config.vocoder) if config.vocoder is not None else None + ) + + self.campplus = None + if config.campplus is not None: + # Reuse the existing MLX CAMPPlus implementation. + from mlx_audio.tts.models.chatterbox.s3gen.xvector import CAMPPlus + + self.campplus = CAMPPlus(**config.campplus) + + self.semantic_codec = None + if config.semantic_codec is not None: + self.semantic_codec = RepCodec(config.semantic_codec) + + # W2V-BERT feature pipeline (semantic encoder to be implemented next) + self.w2vbert_feature_extractor = W2VBertFeatureExtractor( + W2VBertFeatureExtractorConfig() + ) + self.w2vbert_stats = W2VBertStats(dim=1024) + self.w2vbert = None + if getattr(config, "w2vbert", None) and isinstance(config.w2vbert, dict): + # Expect `w2vbert.config` to be the HF config.json dict. 
+ cfg = config.w2vbert.get("config") + if isinstance(cfg, dict): + self.w2vbert = Wav2Vec2BertModel( + Wav2Vec2BertConfig( + hidden_size=int(cfg["hidden_size"]), + num_hidden_layers=int(cfg["num_hidden_layers"]), + num_attention_heads=int(cfg["num_attention_heads"]), + intermediate_size=int(cfg["intermediate_size"]), + feature_projection_input_dim=int(cfg["feature_projection_input_dim"]), + layer_norm_eps=float(cfg.get("layer_norm_eps", 1e-5)), + position_embeddings_type=cfg.get( + "position_embeddings_type", "relative_key" + ), + rotary_embedding_base=int(cfg.get("rotary_embedding_base", 10000)), + max_source_positions=int(cfg.get("max_source_positions", 5000)), + left_max_position_embeddings=int( + cfg.get("left_max_position_embeddings", 64) + ), + right_max_position_embeddings=int( + cfg.get("right_max_position_embeddings", 8) + ), + conv_depthwise_kernel_size=int( + cfg.get("conv_depthwise_kernel_size", 31) + ), + conformer_conv_dropout=float( + cfg.get("conformer_conv_dropout", 0.1) + ), + ) + ) + + self.unifiedvoice = None + if getattr(config, "unifiedvoice", None) and isinstance(config.unifiedvoice, dict): + bpe_model = config.unifiedvoice.get("bpe_model", "bpe.model") + bpe_path = None + if config.model_path is not None: + bpe_path = str((Path(config.model_path) / bpe_model).resolve()) + self.unifiedvoice = UnifiedVoice( + UnifiedVoiceConfig.from_dict(config.unifiedvoice), + bpe_model=bpe_path or bpe_model, + ) + + self.s2mel = None + if getattr(config, "s2mel", None) and isinstance(config.s2mel, dict): + self.s2mel = S2MelModel(S2MelConfig.from_dict(config.s2mel)) + + # TODO: instantiate submodules once implemented and weights are available. 
+ + def _get_emotion(self) -> QwenEmotion: + if self._emotion is None: + self._emotion = QwenEmotion( + QwenEmotionConfig(model=self.config.qwen_emotion_model) + ) + return self._emotion + + def _result(self, audio: mx.array, start_time: float) -> GenerationResult: + samples = int(audio.shape[0]) + audio_duration_seconds = samples / self.sample_rate + elapsed_time = time.perf_counter() - start_time + rtf = (audio_duration_seconds / elapsed_time) if elapsed_time > 0 else 0.0 + + duration_mins = int(audio_duration_seconds // 60) + duration_secs = int(audio_duration_seconds % 60) + duration_ms = int((audio_duration_seconds % 1) * 1000) + duration_hours = int(audio_duration_seconds // 3600) + duration_str = ( + f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}" + ) + + return GenerationResult( + audio=audio, + samples=samples, + sample_rate=self.sample_rate, + segment_idx=0, + token_count=0, + audio_duration=duration_str, + real_time_factor=rtf, + prompt={"tokens": 0, "tokens-per-sec": 0}, + audio_samples={ + "samples": samples, + "samples-per-sec": (round(samples / elapsed_time, 2) if elapsed_time > 0 else 0), + }, + processing_time_seconds=elapsed_time, + peak_memory_usage=mx.get_peak_memory() / 1e9, + ) + + def _s2mel_ref_mel(self, audio: mx.array) -> mx.array: + """Match official IndexTTS2 mel_spectrogram settings for s2mel reference.""" + n_fft = 1024 + hop = 256 + n_mels = 80 + + pad = int((n_fft - hop) / 2) + prefix = audio[1 : pad + 1][::-1] + suffix = audio[-(pad + 1) : -1][::-1] + y = mx.concatenate([prefix, audio, suffix], axis=0) + + spec = stft( + y, + n_fft=n_fft, + hop_length=hop, + win_length=n_fft, + window="hann", + center=False, + pad_mode="reflect", + ).abs() + + fb = mel_filters( + sample_rate=self.sample_rate, + n_fft=n_fft, + n_mels=n_mels, + f_min=0.0, + f_max=None, + norm="slaney", + mel_scale="slaney", + ) + mel = spec @ fb.T + mel = mx.log(mx.maximum(mel, 1e-5)) + return mel.T[None, :, :] + + def 
_align_generated_mel_to_prompt(self, mel: mx.array, ref_mel: mx.array) -> mx.array: + """Match generated mel stats to prompt mel stats. + + This mirrors a common stabilization trick for flow vocoder pipelines when the + generated mel drifts to an overly low-energy range that leads to near-silent + waveform output. + """ + if mel.shape[-1] < 2: + return mel + + ref_mean = mx.mean(ref_mel, axis=-1, keepdims=True) + ref_std = mx.std(ref_mel, axis=-1, keepdims=True) + mel_mean = mx.mean(mel, axis=-1, keepdims=True) + mel_std = mx.std(mel, axis=-1, keepdims=True) + + eps = 1e-5 + mel_n = (mel - mel_mean) / mx.maximum(mel_std, eps) + mel_a = mel_n * ref_std + ref_mean + + # Keep values in a reasonable log-mel range for BigVGAN. + return mx.clip(mel_a, -12.0, 4.0) + + def generate( + self, + text: str, + *, + ref_audio: Optional[Union[str, mx.array]] = None, + ref_text: Optional[str] = None, + speed: float = 1.0, + # Emotion controls + use_emo_text: bool = False, + emo_text: Optional[str] = None, + emo_vector: Optional[list[float]] = None, + emo_alpha: float = 1.0, + repetition_penalty: float = 10.0, + diffusion_steps: int = 40, + diffusion_cfg_rate: float = 0.7, + # Keep signature compatible with mlx_audio.tts.generate + voice: Optional[str] = None, + lang_code: str = "en", + verbose: bool = False, + stream: bool = False, + streaming_interval: float = 2.0, + **kwargs, + ) -> Iterator[GenerationResult]: + del voice, lang_code, ref_text, verbose, stream, streaming_interval, kwargs + + if ref_audio is None: + raise ValueError("IndexTTS2 requires ref_audio (speaker prompt audio)") + + if self.unifiedvoice is None or self.s2mel is None or self.semantic_codec is None: + raise ValueError( + "IndexTTS2 is missing required submodules (unifiedvoice/s2mel/semantic_codec). " + "Make sure you converted all weights into the model folder." 
+ ) + + start_time = time.perf_counter() + + # Load reference audio at 16k (semantic) and 22.05k (mel/vocoder) + ref_16k = load_audio(ref_audio, sample_rate=16000) + ref_22k = load_audio(ref_audio, sample_rate=self.sample_rate) + + # W2V-BERT features -> hidden states + input_features, attn_mask = self.w2vbert_feature_extractor(ref_16k) + last, hstates = self.w2vbert( + input_features, attention_mask=attn_mask, output_hidden_states=True + ) + del last + hs17 = self.w2vbert_stats(hstates[17]) + + # Semantic codec prompt codes + embeddings + ref_codes, ref_quant = self.semantic_codec.quantize(hs17) + # ref_mel: (B, 80, T) + ref_mel = self._s2mel_ref_mel(ref_22k) + ref_mel_len = mx.array([ref_mel.shape[-1]], dtype=mx.int32) + + # Style from CAMPPlus (Kaldi fbank) + fb = compute_fbank_kaldi(ref_16k, sample_rate=16000, num_mels=80, dither=0.0) + fb = fb - mx.mean(fb, axis=0, keepdims=True) + style = self.campplus(fb[None, :, :]) + + prompt_condition, _, _, _, _ = self.s2mel.length_regulator( + ref_quant, ylens=ref_mel_len + ) + + # Emotion vector + emo_vec = None + if use_emo_text and emo_vector is None: + text_for_emo = emo_text if emo_text is not None else text + _, emo_vector = self._get_emotion().infer(text_for_emo) + if emo_vector is not None: + emo_vec = mx.array(emo_vector, dtype=mx.float32)[None, :] + + # Text -> semantic codes via UnifiedVoice + text_tokens = self.unifiedvoice.encode_text(text) + + # Use reference hidden states as speaker/emotion condition. 
+ # UnifiedVoice expects (B, T, 1024) + spk_cond = hs17 + emo_cond = hs17 + + codes, speech_latent = self.unifiedvoice.inference_speech( + spk_cond, + text_tokens, + emo_cond, + alpha=float(emo_alpha), + top_p=0.8, + top_k=30, + temperature=0.8, + max_generate_length=1500, + repetition_penalty=float(repetition_penalty), + ) + + # Strip stop token if present + if codes.shape[1] > 0 and int(codes[0, -1].item()) == self.unifiedvoice.cfg.stop_mel_token: + codes = codes[:, :-1] + + # Get GPT latent and project to 1024 + emo_vec_lat = self.unifiedvoice.get_emovec(emo_cond, mx.array([emo_cond.shape[1]], dtype=mx.int32)) + mel_lat = self.unifiedvoice.forward_latent(speech_latent, text_tokens, codes, emo_vec_lat) + gpt_lat = self.s2mel.project_gpt_latent(mel_lat) + + # Semantic embedding of inferred codes + S_infer = self.semantic_codec.vq2emb(codes[None, :, :]) + S_infer = S_infer + gpt_lat + + code_lens = mx.array([S_infer.shape[1]], dtype=mx.int32) + target_lengths = (code_lens.astype(mx.float32) * 1.72).astype(mx.int32) + + cond, _, _, _, _ = self.s2mel.length_regulator(S_infer, ylens=target_lengths) + cat_condition = mx.concatenate([prompt_condition, cond], axis=1) + + x_lens = mx.array([cat_condition.shape[1]], dtype=mx.int32) + mel_all = self.s2mel.cfm.inference( + cat_condition, + x_lens, + ref_mel, + style, + None, + n_timesteps=int(diffusion_steps), + inference_cfg_rate=float(diffusion_cfg_rate), + ) + mel = mel_all[:, :, ref_mel.shape[-1] :] + + if mel.shape[-1] == 0: + raise ValueError( + "IndexTTS2 generated empty mel sequence (no semantic frames after stop token)." + ) + + # Auto-correct low-energy mel drift before vocoding. 
+ ref_energy = mx.mean(mx.std(ref_mel, axis=-1)) + mel_energy = mx.mean(mx.std(mel, axis=-1)) + ref_level = mx.mean(mx.mean(ref_mel, axis=-1)) + mel_level = mx.mean(mx.mean(mel, axis=-1)) + if float(mel_energy.item()) < float((ref_energy * 0.75).item()) or float( + mel_level.item() + ) < float((ref_level - 1.5).item()): + mel = self._align_generated_mel_to_prompt(mel, ref_mel) + + audio = self.bigvgan(mel) + audio = audio.reshape(-1).astype(mx.float32) + if speed and speed != 1.0: + audio = adjust_speed(audio, speed) + + yield self._result(audio, start_time) + + def vocode(self, mel: mx.array) -> mx.array: + """Run the BigVGAN vocoder. + + Args: + mel: (B, n_mels, T) float mel-spectrogram. + """ + if self.bigvgan is None: + raise ValueError("BigVGAN vocoder is not configured/loaded") + audio = self.bigvgan(mel) + return audio diff --git a/mlx_audio/tts/models/indextts2/s2mel.py b/mlx_audio/tts/models/indextts2/s2mel.py new file mode 100644 index 000000000..fcf4e5583 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/s2mel.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import mlx.nn as nn + +from .s2mel_dit import DiTConfig +from .s2mel_flow_matching import CFM +from .s2mel_length_regulator import InterpolateRegulator, InterpolateRegulatorConfig + + +@dataclass +class S2MelConfig: + dit: DiTConfig + length_regulator: InterpolateRegulatorConfig + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "S2MelConfig": + dit_d = d.get("DiT", {}) + wv_d = d.get("wavenet", {}) + dit = DiTConfig( + hidden_dim=int(dit_d.get("hidden_dim", 512)), + num_heads=int(dit_d.get("num_heads", 8)), + depth=int(dit_d.get("depth", 13)), + in_channels=int(dit_d.get("in_channels", 80)), + content_dim=int(dit_d.get("content_dim", 512)), + style_dim=int(d.get("style_encoder", {}).get("dim", 192)), + is_causal=bool(dit_d.get("is_causal", False)), + 
long_skip_connection=bool(dit_d.get("long_skip_connection", True)), + uvit_skip_connection=bool(dit_d.get("uvit_skip_connection", True)), + final_layer_type=str(dit_d.get("final_layer_type", "wavenet")), + wavenet_hidden_dim=int(wv_d.get("hidden_dim", 512)), + wavenet_num_layers=int(wv_d.get("num_layers", 8)), + wavenet_kernel_size=int(wv_d.get("kernel_size", 5)), + wavenet_dilation_rate=int(wv_d.get("dilation_rate", 1)), + ) + + lr_d = d.get("length_regulator", {}) + lr = InterpolateRegulatorConfig( + channels=int(lr_d.get("channels", 512)), + sampling_ratios=tuple(lr_d.get("sampling_ratios", [1, 1, 1, 1])), + in_channels=int(lr_d.get("in_channels", 1024)), + out_channels=int(lr_d.get("channels", 512)), + groups=1, + ) + return cls(dit=dit, length_regulator=lr) + + +class S2MelModel(nn.Module): + def __init__(self, cfg: S2MelConfig): + super().__init__() + self.cfg = cfg + + self.cfm = CFM(cfg.dit) + self.length_regulator = InterpolateRegulator(cfg.length_regulator) + self.gpt_layer = [ + nn.Linear(1280, 256), + nn.Linear(256, 128), + nn.Linear(128, 1024), + ] + + def project_gpt_latent(self, x): + h = x + for layer in self.gpt_layer: + h = layer(h) + return h diff --git a/mlx_audio/tts/models/indextts2/s2mel_dit.py b/mlx_audio/tts/models/indextts2/s2mel_dit.py new file mode 100644 index 000000000..99a1627c7 --- /dev/null +++ b/mlx_audio/tts/models/indextts2/s2mel_dit.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .s2mel_gpt_fast import GPTFastArgs, GPTFastTransformer +from .s2mel_utils import sequence_mask +from .s2mel_wavenet import WN + + +def _mish(x: mx.array) -> mx.array: + return x * mx.tanh(mx.log1p(mx.exp(x))) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): + super().__init__() + self.mlp = [ + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + 
nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ] + self.frequency_embedding_size = frequency_embedding_size + self.max_period = 10000 + self.scale = 1000 + + half = frequency_embedding_size // 2 + freqs = mx.exp( + -mx.log(mx.array(self.max_period, dtype=mx.float32)) + * (mx.arange(half, dtype=mx.float32) / half) + ) + self.freqs = freqs + + def timestep_embedding(self, t: mx.array) -> mx.array: + args = self.scale * t[:, None].astype(mx.float32) * self.freqs[None] + emb = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1) + if self.frequency_embedding_size % 2: + emb = mx.concatenate([emb, mx.zeros((emb.shape[0], 1), dtype=emb.dtype)], axis=-1) + return emb + + def __call__(self, t: mx.array) -> mx.array: + x = self.timestep_embedding(t) + for layer in self.mlp: + x = layer(x) + return x + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = [ + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ] + + def __call__(self, x: mx.array, c: mx.array) -> mx.array: + h = c + for layer in self.adaLN_modulation: + h = layer(h) + shift, scale = mx.split(h, 2, axis=-1) + x = self.norm_final(x) + x = x * (1 + scale[:, None, :]) + shift[:, None, :] + return self.linear(x) + + +@dataclass +class DiTConfig: + hidden_dim: int = 512 + num_heads: int = 8 + depth: int = 13 + in_channels: int = 80 + content_dim: int = 512 + style_dim: int = 192 + is_causal: bool = False + long_skip_connection: bool = True + uvit_skip_connection: bool = True + final_layer_type: str = "wavenet" # wavenet or mlp + wavenet_hidden_dim: int = 512 + wavenet_num_layers: int = 8 + wavenet_kernel_size: int = 5 + wavenet_dilation_rate: int = 1 + + +class DiT(nn.Module): + def __init__(self, cfg: DiTConfig): + super().__init__() + self.cfg = cfg + + gpt_cfg = GPTFastArgs( + 
block_size=16384, + n_layer=cfg.depth, + n_head=cfg.num_heads, + dim=cfg.hidden_dim, + head_dim=cfg.hidden_dim // cfg.num_heads, + n_local_heads=cfg.num_heads, + intermediate_size=int(2 * (4 * cfg.hidden_dim) / 3), + uvit_skip_connection=cfg.uvit_skip_connection, + ) + gpt_cfg.intermediate_size = 1536 # match checkpoint + self.transformer = GPTFastTransformer(gpt_cfg) + + self.in_channels = cfg.in_channels + + self.x_embedder = nn.Linear(cfg.in_channels, cfg.hidden_dim, bias=True) + + # Present in torch checkpoints but unused for continuous conditioning in IndexTTS2. + self.cond_embedder = nn.Embedding(1024, cfg.hidden_dim) + self.content_mask_embedder = nn.Embedding(1, cfg.hidden_dim) + + self.cond_projection = nn.Linear(cfg.content_dim, cfg.hidden_dim, bias=True) + self.t_embedder = TimestepEmbedder(cfg.hidden_dim) + + # (x + prompt_x + cond) + style + self.cond_x_merge_linear = nn.Linear( + cfg.hidden_dim + cfg.in_channels * 2 + cfg.style_dim, + cfg.hidden_dim, + bias=True, + ) + + self.long_skip_connection = cfg.long_skip_connection + if self.long_skip_connection: + self.skip_linear = nn.Linear(cfg.hidden_dim + cfg.in_channels, cfg.hidden_dim, bias=True) + + self.final_layer_type = cfg.final_layer_type + if self.final_layer_type == "wavenet": + self.t_embedder2 = TimestepEmbedder(cfg.wavenet_hidden_dim) + self.conv1 = nn.Linear(cfg.hidden_dim, cfg.wavenet_hidden_dim, bias=True) + self.conv2 = nn.Conv1d(cfg.wavenet_hidden_dim, cfg.in_channels, kernel_size=1) + self.wavenet = WN( + hidden_channels=cfg.wavenet_hidden_dim, + kernel_size=cfg.wavenet_kernel_size, + dilation_rate=cfg.wavenet_dilation_rate, + n_layers=cfg.wavenet_num_layers, + gin_channels=cfg.wavenet_hidden_dim, + p_dropout=0.0, + ) + self.final_layer = FinalLayer(cfg.wavenet_hidden_dim, cfg.wavenet_hidden_dim) + self.res_projection = nn.Linear(cfg.hidden_dim, cfg.wavenet_hidden_dim, bias=True) + else: + self.final_mlp = nn.Sequential( + nn.Linear(cfg.hidden_dim, cfg.hidden_dim), + nn.SiLU(), + 
nn.Linear(cfg.hidden_dim, cfg.in_channels), + ) + + self.input_pos = mx.arange(16384) + + def __call__( + self, + x: mx.array, + prompt_x: mx.array, + x_lens: mx.array, + t: mx.array, + style: mx.array, + cond: mx.array, + ) -> mx.array: + # x, prompt_x: (B, C, T) + # cond: (B, T, content_dim) + B, C, T = x.shape + + t1 = self.t_embedder(t.astype(mx.float32)) # (B, D) + cond_proj = self.cond_projection(cond) + + x_t = x.transpose(0, 2, 1) # (B, T, C) + p_t = prompt_x.transpose(0, 2, 1) + x_in = mx.concatenate([x_t, p_t, cond_proj, mx.repeat(style[:, None, :], T, axis=1)], axis=-1) + x_in = self.cond_x_merge_linear(x_in) + + # Attention mask (non-causal) + x_mask = sequence_mask(x_lens, max_length=int(T)).astype(mx.bool_) + key_mask = x_mask[:, None, None, :] # (B,1,1,T) + attn_mask = mx.where(key_mask, 0.0, -1e9).astype(mx.float32) + attn_mask = mx.broadcast_to(attn_mask, (B, 1, T, T)) + + input_pos = self.input_pos[:T] + x_res = self.transformer(x_in, t1, input_pos, attn_mask) + + if self.long_skip_connection: + x_res = self.skip_linear(mx.concatenate([x_res, x_t], axis=-1)) + + if self.final_layer_type == "wavenet": + h = self.conv1(x_res) # (B, T, H) + t2 = self.t_embedder2(t.astype(mx.float32)) + g = t2[:, None, :] # (B,1,H) + x_mask_c = x_mask[:, :, None].astype(h.dtype) + h = self.wavenet(h, x_mask_c, g=g) + self.res_projection(x_res) + h = self.final_layer(h, t1) + y = self.conv2(h) + else: + y = self.final_mlp(x_res) + return y.transpose(0, 2, 1) diff --git a/mlx_audio/tts/models/indextts2/s2mel_flow_matching.py b/mlx_audio/tts/models/indextts2/s2mel_flow_matching.py new file mode 100644 index 000000000..079fdcefa --- /dev/null +++ b/mlx_audio/tts/models/indextts2/s2mel_flow_matching.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx + +import mlx.nn as nn + +from .s2mel_dit import DiT, DiTConfig + + +class CFM(nn.Module): + def __init__(self, dit_cfg: 
DiTConfig): + super().__init__() + self.sigma_min = 1e-6 + self.in_channels = dit_cfg.in_channels + self.estimator = DiT(dit_cfg) + + def inference( + self, + mu: mx.array, + x_lens: mx.array, + prompt: mx.array, + style: mx.array, + f0: Optional[mx.array], + n_timesteps: int, + temperature: float = 1.0, + inference_cfg_rate: float = 0.7, + ) -> mx.array: + # mu: (B, T, 512) + B, T, _ = mu.shape + z = mx.random.normal((B, self.in_channels, T)).astype(mx.float32) * float(temperature) + t_span = mx.linspace(0.0, 1.0, n_timesteps + 1) + return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate) + + def solve_euler( + self, + x: mx.array, + x_lens: mx.array, + prompt: mx.array, + mu: mx.array, + style: mx.array, + f0: Optional[mx.array], + t_span: mx.array, + inference_cfg_rate: float, + ) -> mx.array: + del f0 + prompt_len = prompt.shape[-1] + + prompt_x = mx.zeros_like(x) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + + if prompt_len > 0: + x = mx.concatenate([mx.zeros_like(x[..., :prompt_len]), x[..., prompt_len:]], axis=-1) + + t = t_span[0] + for step in range(1, t_span.shape[0]): + dt = t_span[step] - t_span[step - 1] + + if inference_cfg_rate > 0: + zeros_prompt = mx.zeros_like(prompt_x) + zeros_style = mx.zeros_like(style) + zeros_mu = mx.zeros_like(mu) + + stacked_prompt_x = mx.concatenate([prompt_x, zeros_prompt], axis=0) + stacked_style = mx.concatenate([style, zeros_style], axis=0) + stacked_mu = mx.concatenate([mu, zeros_mu], axis=0) + stacked_x = mx.concatenate([x, x], axis=0) + stacked_t = mx.concatenate([mx.array([t]), mx.array([t])], axis=0) + stacked_x_lens = mx.concatenate([x_lens, x_lens], axis=0) + + stacked_d = self.estimator( + stacked_x, + stacked_prompt_x, + stacked_x_lens, + stacked_t, + stacked_style, + stacked_mu, + ) + dphi_dt, cfg_dphi_dt = mx.split(stacked_d, 2, axis=0) + dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt + else: + dphi_dt = self.estimator(x, prompt_x, 
x_lens, mx.array([t]), style, mu) + + x = x + dt * dphi_dt + t = t + dt + if prompt_len > 0: + x = mx.concatenate([mx.zeros_like(x[..., :prompt_len]), x[..., prompt_len:]], axis=-1) + + return x diff --git a/mlx_audio/tts/models/indextts2/s2mel_gpt_fast.py b/mlx_audio/tts/models/indextts2/s2mel_gpt_fast.py new file mode 100644 index 000000000..58c9817cb --- /dev/null +++ b/mlx_audio/tts/models/indextts2/s2mel_gpt_fast.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +def _find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +def _apply_rotary_emb(x: mx.array, freqs: mx.array) -> mx.array: + # x: (B, L, H, D) + # freqs: (L, D/2, 2) with real/imag + xshaped = x.astype(mx.float32).reshape(*x.shape[:-1], -1, 2) + freqs = freqs.reshape(1, xshaped.shape[1], 1, xshaped.shape[3], 2) + re = xshaped[..., 0] * freqs[..., 0] - xshaped[..., 1] * freqs[..., 1] + im = xshaped[..., 1] * freqs[..., 0] + xshaped[..., 0] * freqs[..., 1] + out = mx.stack([re, im], axis=-1).reshape(x.shape) + return out.astype(x.dtype) + + +def _precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> mx.array: + freqs = 1.0 / (base ** (mx.arange(0, n_elem, 2, dtype=mx.float32) / n_elem)) + t = mx.arange(seq_len, dtype=mx.float32) + freqs = mx.outer(t, freqs) + return mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1).astype(mx.float32) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = mx.ones((dim,), dtype=mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + x_f = x.astype(mx.float32) + denom = mx.rsqrt(mx.mean(x_f * x_f, axis=-1, keepdims=True) + self.eps) + y = x_f * denom + return (y * self.weight).astype(x.dtype) + + +class AdaptiveLayerNorm(nn.Module): + def __init__(self, d_model: int, norm: nn.Module): + 
super().__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + + def __call__(self, x: mx.array, embedding: Optional[mx.array] = None) -> mx.array: + if embedding is None: + return self.norm(x) + weight, bias = mx.split(self.project_layer(embedding), 2, axis=-1) + return weight[:, None, :] * self.norm(x) + bias[:, None, :] + + +@dataclass +class GPTFastArgs: + block_size: int = 16384 + n_layer: int = 13 + n_head: int = 8 + dim: int = 512 + head_dim: int = 64 + n_local_heads: int = 8 + intermediate_size: int = 1536 + rope_base: float = 10000 + norm_eps: float = 1e-5 + uvit_skip_connection: bool = True + time_as_token: bool = False + + +class GPTFastAttention(nn.Module): + def __init__(self, config: GPTFastArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.n_head = config.n_head + self.n_local_heads = config.n_local_heads + self.head_dim = config.head_dim + self.scale = self.head_dim**-0.5 + + def __call__( + self, + x: mx.array, + freqs_cis: mx.array, + mask: mx.array, + ) -> mx.array: + B, L, _ = x.shape + kv_size = self.n_local_heads * self.head_dim + q, k, v = mx.split(self.wqkv(x), [kv_size, 2 * kv_size], axis=-1) + + q = q.reshape(B, L, self.n_head, self.head_dim) + k = k.reshape(B, L, self.n_local_heads, self.head_dim) + v = v.reshape(B, L, self.n_local_heads, self.head_dim) + + q = _apply_rotary_emb(q, freqs_cis) + k = _apply_rotary_emb(k, freqs_cis) + + q = q.transpose(0, 2, 1, 3) + k = k.transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + + if self.n_local_heads < self.n_head: + n_rep = self.n_head // self.n_local_heads + k = mx.repeat(k, n_rep, axis=1) + v = mx.repeat(v, n_rep, axis=1) + + y = mx.fast.scaled_dot_product_attention(q, k, v, 
scale=self.scale, mask=mask) + y = y.transpose(0, 2, 1, 3).reshape(B, L, self.n_head * self.head_dim) + return self.wo(y) + + +class GPTFastFeedForward(nn.Module): + def __init__(self, config: GPTFastArgs): + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def __call__(self, x: mx.array) -> mx.array: + return self.w2(nn.silu(self.w1(x)) * self.w3(x)) + + +class GPTFastBlock(nn.Module): + def __init__(self, config: GPTFastArgs): + super().__init__() + self.attention = GPTFastAttention(config) + self.feed_forward = GPTFastFeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm( + config.dim, RMSNorm(config.dim, eps=config.norm_eps) + ) + + if config.uvit_skip_connection: + self.skip_in_linear = nn.Linear(config.dim * 2, config.dim) + self.uvit_skip_connection = True + else: + self.uvit_skip_connection = False + + self.time_as_token = config.time_as_token + + def __call__( + self, + x: mx.array, + c: mx.array, + input_pos: mx.array, + freqs_cis: mx.array, + mask: mx.array, + skip_in_x: Optional[mx.array] = None, + ) -> mx.array: + c_in = None if self.time_as_token else c + if self.uvit_skip_connection and skip_in_x is not None: + x = self.skip_in_linear(mx.concatenate([x, skip_in_x], axis=-1)) + + h = x + self.attention(self.attention_norm(x, c_in), freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h, c_in)) + return out + + +class GPTFastTransformer(nn.Module): + def __init__(self, config: GPTFastArgs): + super().__init__() + self.config = config + self.layers = [GPTFastBlock(config) for _ in range(config.n_layer)] + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.freqs_cis = _precompute_freqs_cis(config.block_size, 
config.head_dim, int(config.rope_base)) + + self.uvit_skip_connection = config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(config.n_layer) if i < config.n_layer // 2] + self.layers_receive_skip = [i for i in range(config.n_layer) if i > config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip = [] + + def __call__( + self, + x: mx.array, + c: mx.array, + input_pos: mx.array, + mask: mx.array, + ) -> mx.array: + freqs = self.freqs_cis[input_pos] + skip_stack = [] + for i, layer in enumerate(self.layers): + if self.uvit_skip_connection and i in self.layers_receive_skip: + skip_in_x = skip_stack.pop() + else: + skip_in_x = None + x = layer(x, c, input_pos, freqs, mask, skip_in_x=skip_in_x) + if self.uvit_skip_connection and i in self.layers_emit_skip: + skip_stack.append(x) + return self.norm(x, c) diff --git a/mlx_audio/tts/models/indextts2/s2mel_length_regulator.py b/mlx_audio/tts/models/indextts2/s2mel_length_regulator.py new file mode 100644 index 000000000..e3804b42b --- /dev/null +++ b/mlx_audio/tts/models/indextts2/s2mel_length_regulator.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .s2mel_utils import sequence_mask + + +@dataclass +class InterpolateRegulatorConfig: + channels: int = 512 + sampling_ratios: Tuple[int, ...] 
def _nearest_interpolate_1d(x: mx.array, out_len: int) -> mx.array:
    """Resample ``x`` along its first axis to ``out_len`` rows (nearest neighbour).

    Output row ``i`` is taken from input row ``floor(i * in_len / out_len)``,
    clipped to the valid range.

    Args:
        x: Array of shape ``(T, C)``.
        out_len: Number of rows to produce.

    Returns:
        ``x`` itself when the length already matches, otherwise an array of
        shape ``(out_len, C)``. A non-positive ``out_len`` yields an empty
        ``(0, C)`` array.
    """
    in_len = int(x.shape[0])
    if in_len == out_len:
        return x
    if out_len <= 0:
        # Nothing to produce; returning early also keeps the scale factor
        # below from dividing by zero.
        return x[:0, :]
    if in_len <= 1:
        # A single frame can only be repeated.
        return mx.broadcast_to(x[:1, :], (out_len, x.shape[1]))

    scale = in_len / out_len
    idx = mx.floor(mx.arange(out_len, dtype=mx.float32) * scale).astype(mx.int32)
    idx = mx.clip(idx, 0, in_len - 1)
    return x[idx, :]
def sequence_mask(lengths: mx.array, max_length: int | None = None) -> mx.array:
    """Build a boolean mask that is true for positions below each length.

    Args:
        lengths: Per-sequence lengths, shape ``(B,)``.
        max_length: Width of the mask; defaults to the largest entry of
            ``lengths``.

    Returns:
        Boolean array of shape ``(B, max_length)`` where entry ``[b, t]`` is
        true when ``t < lengths[b]``.
    """
    width = int(mx.max(lengths).item()) if max_length is None else max_length
    positions = mx.arange(width)
    per_row_limit = lengths[:, None]
    return positions[None, :] < per_row_limit
def _conv1d_weightnorm_strip(weight_g: mx.array, weight_v: mx.array, eps: float = 1e-8) -> mx.array:
    """Fold a weight-norm pair into a plain conv weight: ``w = g * v / ||v||``.

    The norm is taken per output channel, i.e. over axes 1 and 2 of
    ``weight_v``. ``eps`` is added under the square root.

    Args:
        weight_g: Per-output-channel gain; reshaped to ``(-1, 1, 1)``.
        weight_v: Direction tensor, assumed to already be in MLX conv layout
            ``(O, K, I)`` (the converter is expected to transpose first).
        eps: Stabiliser added to the squared norm.

    Returns:
        Array with the same shape as ``weight_v``.
    """
    squared_norm = mx.sum(weight_v * weight_v, axis=(1, 2), keepdims=True)
    per_channel_norm = mx.sqrt(squared_norm + eps)
    gain = weight_g.reshape(-1, 1, 1) / per_channel_norm
    return weight_v * gain
class FactorizedVectorQuantize(nn.Module):
    """Single-codebook vector quantizer (MLX port of Amphion's module).

    Inputs are shaped ``(B, T, D)``. When ``input_dim`` differs from
    ``codebook_dim`` the latents are projected into and out of the codebook
    space with 1x1 ``WNConv1d`` layers; otherwise both projections are
    identities.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        use_l2_normlize: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.use_l2_normlize = use_l2_normlize

        needs_projection = self.input_dim != self.codebook_dim
        if needs_projection:
            self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1)
            self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1)
        else:
            self.in_project = nn.Identity()
            self.out_project = nn.Identity()

        self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim)

    def decode_latents(self, latents: mx.array) -> Tuple[mx.array, mx.array]:
        """Pick the nearest codebook entry for every latent vector.

        Args:
            latents: Array of shape ``(B, T, D)`` in codebook space.

        Returns:
            ``(z_q, indices)`` where ``indices`` has shape ``(B, T)`` and
            ``z_q`` is the codebook embedding of those indices. The lookup
            uses the stored (un-normalised) codebook even when the search
            was done on L2-normalised vectors.
        """
        batch, steps, width = latents.shape
        flat = latents.reshape(batch * steps, width)

        table = self.codebook.weight  # (K, D)
        if self.use_l2_normlize:
            flat = _l2_normalize(flat, axis=-1)
            table = _l2_normalize(table, axis=-1)

        # Squared euclidean distance, expanded as |e|^2 - 2 e.c + |c|^2.
        enc_sq = mx.sum(flat * flat, axis=1, keepdims=True)
        cross = 2.0 * (flat @ table.T)
        code_sq = mx.sum(table * table, axis=1)[None, :]
        dist = enc_sq - cross + code_sq  # (B*T, K)

        indices = mx.argmax(-dist, axis=1).reshape(batch, steps)
        z_q = self.codebook(indices)  # (B, T, D)
        return z_q, indices

    def __call__(self, z: mx.array) -> Tuple[mx.array, mx.array]:
        """Quantize ``z`` of shape ``(B, T, D)``; return ``(z_q, indices)``."""
        projected = self.in_project(z)
        quantized, indices = self.decode_latents(projected)
        return self.out_project(quantized), indices

    def vq2emb(self, vq: mx.array, *, out_proj: bool = True) -> mx.array:
        """Map code indices to embeddings, optionally through ``out_project``."""
        emb = self.codebook(vq)
        return self.out_project(emb) if out_proj else emb
class RepCodec(nn.Module):
    """MLX port of MaskGCT RepCodec (semantic codec).

    Holds an encoder and a decoder, each a ``VocosBackbone`` followed by a
    linear projection back to ``hidden_size``, plus a ``ResidualVQ``
    quantizer. Optional ``down``/``up`` convolutions are created only when
    ``cfg.downsample_scale`` is greater than 1.
    """

    def __init__(self, cfg: RepCodecConfig):
        super().__init__()
        self.cfg = cfg

        self.codebook_size = cfg.codebook_size
        self.codebook_dim = cfg.codebook_dim
        self.hidden_size = cfg.hidden_size

        def _hidden_conv(stride: int) -> nn.Conv1d:
            return nn.Conv1d(
                cfg.hidden_size, cfg.hidden_size, kernel_size=3, stride=stride, padding=1
            )

        def _backbone_with_projection() -> list:
            return [
                VocosBackbone(
                    input_channels=cfg.hidden_size,
                    dim=cfg.vocos_dim,
                    intermediate_dim=cfg.vocos_intermediate_dim,
                    num_layers=cfg.vocos_num_layers,
                    adanorm_num_embeddings=None,
                ),
                nn.Linear(cfg.vocos_dim, cfg.hidden_size),
            ]

        if cfg.downsample_scale and cfg.downsample_scale > 1:
            # NOTE(review): the stride is fixed at 2 whatever the configured
            # scale is — confirm against the reference implementation.
            self.down = _hidden_conv(2)
            self.up = _hidden_conv(1)
        else:
            self.down = None
            self.up = None

        self.encoder = _backbone_with_projection()
        self.decoder = _backbone_with_projection()

        self.quantizer = ResidualVQ(
            input_dim=cfg.hidden_size,
            num_quantizers=cfg.num_quantizers,
            codebook_size=cfg.codebook_size,
            codebook_dim=cfg.codebook_dim,
            use_l2_normlize=True,
        )

    def _encode(self, x: mx.array) -> mx.array:
        """Run the encoder backbone and its projection on ``x`` (B, T, C)."""
        backbone, projection = self.encoder[0], self.encoder[1]
        return projection(backbone(x))

    def _decode(self, x: mx.array) -> mx.array:
        """Run the decoder backbone and its projection on ``x``."""
        backbone, projection = self.decoder[0], self.decoder[1]
        return projection(backbone(x))

    def quantize(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """Encode and quantize ``x``.

        Returns:
            ``(indices, quantized)``. With a single quantizer the leading
            quantizer axis of ``indices`` is dropped, matching the torch
            method's return.
        """
        if self.down is not None:
            # Optional downsampling followed by GELU.
            x = nn.gelu(self.down(x))

        quantized_out, all_indices = self.quantizer(self._encode(x))
        if all_indices.shape[0] == 1:
            return all_indices[0], quantized_out
        return all_indices, quantized_out

    def vq2emb(self, vq: mx.array, *, n_quantizers: Optional[int] = None) -> mx.array:
        """Delegate index-to-embedding lookup to the quantizer."""
        return self.quantizer.vq2emb(vq, n_quantizers=n_quantizers)

    def sanitize(self, weights: dict[str, mx.array]) -> dict[str, mx.array]:
        """Best-effort conversion of PyTorch conv weights to MLX layout.

        A 3-D weight whose shape differs from this module's parameter of the
        same name is transposed from ``(O, I, K)`` to ``(O, K, I)`` when its
        first two axes line up with the target's first and last axes. Keys
        unknown to this module, and weights that already match, pass through
        unchanged.
        """
        expected = dict(tree_flatten(self.parameters()))
        remapped = {}
        for name, value in weights.items():
            target = expected.get(name)
            if (
                target is not None
                and value.ndim == 3
                and target.ndim == 3
                and value.shape != target.shape
                and value.shape[0] == target.shape[0]
                and value.shape[1] == target.shape[2]
            ):
                # Torch conv1d (O, I, K) -> MLX (O, K, I)
                value = value.transpose(0, 2, 1)
            remapped[name] = value
        return remapped
class PerceiverAttend(nn.Module):
    """Thin wrapper around MLX's fused scaled-dot-product attention.

    ``dropout`` and ``causal`` are stored on the instance but are not
    consulted by ``__call__``.
    """

    def __init__(self, *, dropout: float = 0.0, causal: bool = False):
        super().__init__()
        self.dropout = dropout
        self.causal = causal

    def __call__(self, q: mx.array, k: mx.array, v: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Attend with ``q``, ``k``, ``v`` shaped ``(B, H, L, D)``.

        The scale is ``D ** -0.5`` where ``D`` is the last axis of ``q``;
        ``mask`` is forwarded unchanged to the fused kernel.
        """
        head_dim = q.shape[-1]
        return mx.fast.scaled_dot_product_attention(
            q, k, v, scale=head_dim ** -0.5, mask=mask
        )
@dataclass
class UnifiedVoiceConfig:
    """Hyper-parameters for ``UnifiedVoice``.

    The first eleven integer fields and ``condition_type`` are mandatory;
    the remaining fields carry defaults in ``from_dict``.
    """

    model_dim: int
    heads: int
    layers: int
    max_mel_tokens: int
    max_text_tokens: int
    number_text_tokens: int
    number_mel_codes: int
    start_mel_token: int
    stop_mel_token: int
    start_text_token: int
    stop_text_token: int
    condition_type: str
    condition_module: Dict[str, Any]
    emo_condition_module: Dict[str, Any]
    condition_num_latent: int = 32
    max_conditioning_inputs: int = 1
    mel_length_compression: int = 1024

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "UnifiedVoiceConfig":
        """Build a config from a plain mapping, coercing each value's type.

        Args:
            d: Mapping holding the configuration values.

        Returns:
            A new ``UnifiedVoiceConfig``. The two condition-module mappings
            are shallow-copied and default to empty dicts when absent.

        Raises:
            KeyError: If a mandatory key is missing.
        """
        mandatory_ints = (
            "model_dim",
            "heads",
            "layers",
            "max_mel_tokens",
            "max_text_tokens",
            "number_text_tokens",
            "number_mel_codes",
            "start_mel_token",
            "stop_mel_token",
            "start_text_token",
            "stop_text_token",
        )
        defaulted_ints = (
            ("condition_num_latent", 32),
            ("max_conditioning_inputs", 1),
            ("mel_length_compression", 1024),
        )

        # Keys are read in field order so the first missing key reported is
        # the same one a positional read would hit.
        values: Dict[str, Any] = {key: int(d[key]) for key in mandatory_ints}
        values["condition_type"] = str(d["condition_type"])
        for key in ("condition_module", "emo_condition_module"):
            values[key] = dict(d.get(key, {}))
        for key, fallback in defaulted_ints:
            values[key] = int(d.get(key, fallback))
        return cls(**values)
linear_units=int(cfg.condition_module["linear_units"]), + attention_heads=int(cfg.condition_module["attention_heads"]), + num_blocks=int(cfg.condition_module["num_blocks"]), + input_layer=str(cfg.condition_module["input_layer"]), + ) + ) + self.perceiver_encoder = PerceiverResampler( + cfg.model_dim, + dim_context=int(cfg.condition_module["output_size"]), + heads=int(cfg.condition_module["attention_heads"]), + ff_mult=int(cfg.condition_module.get("perceiver_mult", 2)), + num_latents=cfg.condition_num_latent, + ) + + self.emo_conditioning_encoder = ConformerEncoder( + ConformerEncoderConfig( + input_size=1024, + output_size=int(cfg.emo_condition_module["output_size"]), + linear_units=int(cfg.emo_condition_module["linear_units"]), + attention_heads=int(cfg.emo_condition_module["attention_heads"]), + num_blocks=int(cfg.emo_condition_module["num_blocks"]), + input_layer=str(cfg.emo_condition_module["input_layer"]), + ) + ) + self.emo_perceiver_encoder = PerceiverResampler( + 1024, + dim_context=int(cfg.emo_condition_module["output_size"]), + heads=int(cfg.emo_condition_module["attention_heads"]), + ff_mult=int(cfg.emo_condition_module.get("perceiver_mult", 2)), + num_latents=1, + ) + + self.emo_layer = nn.Linear(cfg.model_dim, cfg.model_dim) + self.emovec_layer = nn.Linear(1024, cfg.model_dim) + + self.speed_emb = nn.Embedding(2, cfg.model_dim) + self.final_norm = nn.LayerNorm(cfg.model_dim) + + self.gpt = GPT2Model( + GPT2Args( + "gpt2", + 1, + cfg.model_dim, + cfg.heads, + cfg.layers, + 1, + 1e-5, + 1, + ) + ) + # Patch GPT2 token/pos embeddings; we feed embeddings directly. + self.gpt.wte = nn.Identity() # type: ignore + self.gpt.wpe = nn.Identity() # type: ignore + + @staticmethod + def _pretokenize_text(text: str) -> str: + # Match official IndexTTS tokenizer pre-tokenizer behavior: + # split around CJK chars and uppercase non-CJK spans. 
+ cjk_range = ( + r"([\u1100-\u11ff\u2e80-\ua4cf\ua840-\uD7AF\uF900-\uFAFF\uFE30-\uFE4F\uFF65-\uFFDC\U00020000-\U0002FFFF])" + ) + parts = re.split(cjk_range, text.strip()) + return " ".join(p.strip().upper() for p in parts if p.strip()) + + def encode_text(self, text: str) -> mx.array: + text = self._pretokenize_text(text) + ids = self.tokenizer.encode(text) + return mx.array(ids, dtype=mx.int32)[None, :] + + def get_conditioning(self, x: mx.array, x_lens: mx.array) -> mx.array: + hs, mask = self.conditioning_encoder(x, x_lens) + # Build perceiver mask: allow attending to context tokens only + return self.perceiver_encoder(hs, mask=mask[:, 0, :]) + + def get_emo_conditioning(self, x: mx.array, x_lens: mx.array) -> mx.array: + hs, mask = self.emo_conditioning_encoder(x, x_lens) + lat = self.emo_perceiver_encoder(hs, mask=mask[:, 0, :]) + return lat[:, 0, :] + + def get_emovec(self, emo_cond: mx.array, emo_lens: mx.array) -> mx.array: + v = self.get_emo_conditioning(emo_cond, emo_lens) + v = self.emovec_layer(v) + return self.emo_layer(v) + + def merge_emovec( + self, + spk_cond: mx.array, + emo_cond: mx.array, + spk_lens: mx.array, + emo_lens: mx.array, + *, + alpha: float = 1.0, + ) -> mx.array: + emo_vec = self.get_emovec(emo_cond, emo_lens) + base_vec = self.get_emovec(spk_cond, spk_lens) + return base_vec + float(alpha) * (emo_vec - base_vec) + + def forward_latent( + self, + speech_conditioning_latent: mx.array, + text_tokens: mx.array, + codes: mx.array, + emo_vec: mx.array, + ) -> mx.array: + # Build embeddings for forward pass similar to official. 
+ # speech_conditioning_latent: (B, 32, D) + B = text_tokens.shape[0] + use_speed = mx.zeros((B,), dtype=mx.int32) + + duration_emb = self.speed_emb(use_speed) + duration_emb_half = self.speed_emb(mx.ones_like(use_speed)) + conds = mx.concatenate( + [speech_conditioning_latent + emo_vec[:, None, :], duration_emb_half[:, None, :], duration_emb[:, None, :]], + axis=1, + ) + + # Text: add stop token + text_tokens = mx.concatenate( + [text_tokens, mx.full((B, 1), self.cfg.stop_text_token, dtype=mx.int32)], + axis=1, + ) + text_inp = mx.concatenate( + [mx.full((B, 1), self.cfg.start_text_token, dtype=mx.int32), text_tokens], + axis=1, + ) + text_emb = self.text_embedding(text_inp) + self.text_pos_embedding(text_inp) + + # Codes: add stop token + codes = mx.concatenate( + [codes, mx.full((B, 1), self.cfg.stop_mel_token, dtype=mx.int32)], + axis=1, + ) + mel_inp = mx.concatenate( + [mx.full((B, 1), self.cfg.start_mel_token, dtype=mx.int32), codes], + axis=1, + ) + mel_emb = self.mel_embedding(mel_inp) + self.mel_pos_embedding(mel_inp) + + emb = mx.concatenate([conds, text_emb, mel_emb], axis=1) + mask = create_attention_mask(emb, cache=None) + hs = self.gpt(emb, mask=mask, cache=None) + hs = self.final_norm(hs[:, conds.shape[1] :, :]) + + # Return mel latent portion (strip the two extra tokens) + mel_lat = hs[:, text_emb.shape[1] : text_emb.shape[1] + mel_emb.shape[1], :] + return mel_lat[:, :-2, :] + + def inference_speech( + self, + speech_condition: mx.array, + text_tokens: mx.array, + emo_condition: Optional[mx.array] = None, + *, + alpha: float = 1.0, + top_p: float = 0.8, + top_k: int = 30, + temperature: float = 0.8, + max_generate_length: int = 1500, + repetition_penalty: float = 10.0, + ) -> Tuple[mx.array, mx.array]: + # speech_condition: (B, T, 1024) + if emo_condition is None: + emo_condition = speech_condition + + B = speech_condition.shape[0] + spk_lens = mx.array([speech_condition.shape[1]] * B, dtype=mx.int32) + emo_lens = mx.array([emo_condition.shape[1]] 
* B, dtype=mx.int32) + + speech_conditioning_latent = self.get_conditioning(speech_condition, spk_lens) + emo_vec = self.merge_emovec(speech_condition, emo_condition, spk_lens, emo_lens, alpha=alpha) + + use_speed = mx.zeros((B,), dtype=mx.int32) + duration_emb = self.speed_emb(use_speed) + duration_emb_half = self.speed_emb(mx.ones_like(use_speed)) + conds = mx.concatenate( + [speech_conditioning_latent + emo_vec[:, None, :], duration_emb_half[:, None, :], duration_emb[:, None, :]], + axis=1, + ) + + # Build text emb with start/stop + text_tokens = mx.concatenate( + [mx.full((B, 1), self.cfg.start_text_token, dtype=mx.int32), text_tokens], + axis=1, + ) + text_tokens = mx.concatenate( + [text_tokens, mx.full((B, 1), self.cfg.stop_text_token, dtype=mx.int32)], + axis=1, + ) + text_emb = self.text_embedding(text_tokens) + self.text_pos_embedding(text_tokens) + + # Start mel + cur = mx.full((B, 1), self.cfg.start_mel_token, dtype=mx.int32) + mel_emb0 = self.mel_embedding(cur) + self.mel_pos_embedding(cur) + + # First forward with full prefix + emb0 = mx.concatenate([conds, text_emb, mel_emb0], axis=1) + mask0 = create_attention_mask(emb0, cache=None) + cache = [KVCache() for _ in range(self.cfg.layers)] + hs = self.gpt(emb0, mask=mask0, cache=cache) + hs_last = self.final_norm(hs[:, -1:, :]) + logits = self.mel_head(hs_last)[:, 0, :] + + sampler = make_sampler(temp=temperature, top_p=top_p, top_k=top_k) + + def apply_repetition_penalty(logits_: mx.array, generated: list[mx.array]) -> mx.array: + if repetition_penalty is None or repetition_penalty <= 1.0 or len(generated) == 0: + return logits_ + # HF-style repetition penalty for batch size 1 (IndexTTS2 inference path). + # For this model we run B=1 in practice. 
class RelPositionalEncoding(nn.Module):
    """Sinusoidal positional table that also scales its input.

    ``__call__`` returns the input multiplied by ``sqrt(d_model)`` together
    with the slice of the table covering the requested positions. The table
    grows on demand when a request runs past ``max_len``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(d_model)
        self.max_len = max_len
        self.pe = self._build_table(max_len, d_model)

    @staticmethod
    def _build_table(max_len: int, d_model: int) -> mx.array:
        """Return a ``(1, max_len, d_model)`` sin/cos table (sin on even columns)."""
        position = mx.arange(max_len).astype(mx.float32)[:, None]
        div_term = mx.exp(
            mx.arange(0, d_model, 2, dtype=mx.float32)
            * (-(math.log(10000.0) / d_model))
        )
        angles = position * div_term
        pe = mx.zeros((max_len, d_model), dtype=mx.float32)
        pe[:, 0::2] = mx.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the cosine half to fit; a no-op when d_model is even.
        pe[:, 1::2] = mx.cos(angles)[:, : d_model // 2]
        return pe[None, :, :]

    def __call__(self, x: mx.array, offset: int = 0) -> Tuple[mx.array, mx.array]:
        """Scale ``x`` and fetch matching positional rows.

        Args:
            x: Input whose second axis is the sequence length.
            offset: Index of the first position to fetch.

        Returns:
            ``(x * sqrt(d_model), pos_emb)`` where ``pos_emb`` covers
            positions ``offset .. offset + x.shape[1]`` and is cast to the
            dtype of ``x``.
        """
        needed = offset + x.shape[1]
        if needed > self.max_len:
            # Rebuild only the table. Re-invoking __init__ here would also
            # re-run nn.Module initialisation on a live module.
            self.max_len = needed + 1
            self.pe = self._build_table(self.max_len, self.d_model)
        pos_emb = self.pe[:, offset:needed].astype(x.dtype)
        return x * self.xscale, pos_emb
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-head attention with learned content/position biases.

    Scores are the sum of a content term ``(q + u) k^T`` and a position term
    ``(q + v) p^T``, where ``p`` is a linear projection of ``pos_emb`` and
    ``u``/``v`` are per-head bias vectors initialised to zero.
    """

    def __init__(self, n_head: int, n_feat: int):
        super().__init__(n_head, n_feat)
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.pos_bias_u = mx.zeros((self.h, self.d_k), dtype=mx.float32)
        self.pos_bias_v = mx.zeros((self.h, self.d_k), dtype=mx.float32)

    def __call__(
        self,
        query: mx.array,
        key: mx.array,
        value: mx.array,
        mask: Optional[mx.array] = None,
        pos_emb: Optional[mx.array] = None,
    ) -> mx.array:
        """Attend over ``key``/``value`` using relative position scores.

        Args:
            query, key, value: Inputs projected by ``forward_qkv``.
            mask: Optional ``(B, 1, T)`` mask; positions equal to zero are
                filled with ``-1e9`` before the softmax.
            pos_emb: Positional embedding to project; mandatory.

        Raises:
            ValueError: If ``pos_emb`` is ``None``.
        """
        if pos_emb is None:
            raise ValueError("pos_emb required")

        q, k, v = self.forward_qkv(query, key, value)  # each (B, H, T, D)
        batch = query.shape[0]

        pos = self.linear_pos(pos_emb)
        pos = pos.reshape(pos_emb.shape[0], -1, self.h, self.d_k).transpose(0, 2, 1, 3)

        # Biases are (H, D); index them so they broadcast over (B, H, T, D).
        q_with_u = q + self.pos_bias_u[None, :, None, :]
        q_with_v = q + self.pos_bias_v[None, :, None, :]

        content_scores = mx.matmul(q_with_u, k.swapaxes(-2, -1))
        position_scores = mx.matmul(q_with_v, pos.swapaxes(-2, -1))
        scores = (content_scores + position_scores) * (self.d_k**-0.5)

        if mask is not None and mask.size > 0:
            blocked = (mask == 0)[:, :, None, :]  # (B, 1, 1, T)
            scores = mx.where(blocked, -1e9, scores)

        weights = mx.softmax(scores, axis=-1)
        context = mx.matmul(weights, v)
        context = context.transpose(0, 2, 1, 3).reshape(batch, -1, self.h * self.d_k)
        return self.linear_out(context)
class ConvolutionModule(nn.Module):
    """Pointwise conv -> GLU -> depthwise conv -> LayerNorm -> SiLU -> pointwise conv.

    Args:
        channels: Width of the channel-last input and output.
        kernel_size: Depthwise kernel width. It must be odd so that the
            ``(kernel_size - 1) // 2`` padding preserves the sequence length.
        bias: Whether the three convolutions carry a bias term.
    """

    def __init__(self, channels: int, kernel_size: int = 15, bias: bool = True):
        super().__init__()
        assert kernel_size % 2 == 1
        half_width = (kernel_size - 1) // 2
        # Doubles the channels so the GLU can gate one half with the other.
        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=half_width,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.LayerNorm(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = nn.SiLU()

    def __call__(self, x: mx.array, mask_pad: Optional[mx.array] = None) -> mx.array:
        """Apply the convolution stack to a channel-last ``(B, T, C)`` input.

        Args:
            x: ``(B, T, C)`` activations.
            mask_pad: Optional ``(B, 1, T)`` mask. Positions where it is false
                are zeroed before the first and after the last convolution.
                ``None`` or an empty array skips masking.

        Returns:
            ``(B, T, C)`` activations.
        """
        use_mask = mask_pad is not None and mask_pad.size > 0
        # (B, 1, T) -> (B, T, 1) so the mask broadcasts over channels.
        keep = mask_pad.transpose(0, 2, 1) if use_mask else None

        if use_mask:
            x = mx.where(keep, x, 0.0)

        x = nn.glu(self.pointwise_conv1(x), axis=-1)
        x = self.norm(self.depthwise_conv(x))
        x = self.pointwise_conv2(self.activation(x))

        if use_mask:
            x = mx.where(keep, x, 0.0)
        return x
class ConformerEncoder(nn.Module):
    """Subsampling front-end followed by a stack of conformer layers.

    Args:
        cfg: Sizes for the front-end and the layers. ``input_size``,
            ``output_size``, ``attention_heads``, ``linear_units`` and
            ``num_blocks`` are read here.
    """

    def __init__(self, cfg: ConformerEncoderConfig):
        super().__init__()
        width = cfg.output_size
        self.embed = Conv2dSubsampling2(cfg.input_size, width)
        self.after_norm = nn.LayerNorm(width, eps=1e-5)

        blocks = []
        for _ in range(cfg.num_blocks):
            attention = RelPositionMultiHeadedAttention(cfg.attention_heads, width)
            feed_forward = PositionwiseFeedForward(width, cfg.linear_units)
            conv = ConvolutionModule(width, kernel_size=15)
            blocks.append(ConformerEncoderLayer(width, attention, feed_forward, conv))
        self.encoders = blocks

    def __call__(self, xs: mx.array, xs_lens: mx.array) -> Tuple[mx.array, mx.array]:
        """Encode a padded batch.

        Args:
            xs: ``(B, T, F)`` input features.
            xs_lens: ``(B,)`` number of valid frames per item.

        Returns:
            ``(encoded, mask)``: the normalized layer output and the
            ``(B, 1, T')`` validity mask returned by the subsampling front-end.
        """
        num_frames = xs.shape[1]
        frame_index = mx.arange(num_frames)[None, :]
        # True where the frame index is below the item's length: (B, 1, T).
        valid = (frame_index < xs_lens[:, None])[:, None, :]

        xs, pos_emb, valid = self.embed(xs, valid, 0)
        for block in self.encoders:
            # The same mask serves as both attention mask and padding mask.
            xs = block(xs, valid, pos_emb, valid)
        return self.after_norm(xs), valid
@dataclass
class Wav2Vec2BertConfig:
    """Hyper-parameters for the MLX Wav2Vec2-BERT encoder.

    Field order and defaults are part of the constructor signature.
    """

    # --- Model size ---
    hidden_size: int = 1024
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    intermediate_size: int = 4096
    # Width of the stacked input features fed to the feature projection.
    feature_projection_input_dim: int = 160
    layer_norm_eps: float = 1e-5

    # --- Attention ---
    attention_dropout: float = 0.0
    # One of rotary|relative|relative_key|None.
    position_embeddings_type: Optional[str] = "relative_key"
    rotary_embedding_base: int = 10000
    max_source_positions: int = 5000
    # Clipping range for relative distances when using "relative_key".
    left_max_position_embeddings: int = 64
    right_max_position_embeddings: int = 8

    # --- Conformer convolution ---
    conv_depthwise_kernel_size: int = 31
    conformer_conv_dropout: float = 0.1

    # --- Dropouts (inference only; kept for compatibility) ---
    hidden_dropout: float = 0.0
    activation_dropout: float = 0.0
    feat_proj_dropout: float = 0.0
class Wav2Vec2BertFeedForward(nn.Module):
    """Two-layer feed-forward block with a swish activation in between.

    Args:
        config: Supplies ``intermediate_size`` and the default ``hidden_size``.
        hidden_size: Optional override for the input/output width.
    """

    def __init__(self, config: Wav2Vec2BertConfig, *, hidden_size: Optional[int] = None):
        super().__init__()
        width = config.hidden_size if hidden_size is None else hidden_size
        self.intermediate_dense = nn.Linear(width, config.intermediate_size)
        self.output_dense = nn.Linear(config.intermediate_size, width)

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Return ``output_dense(swish(intermediate_dense(hidden_states)))``."""
        expanded = swish(self.intermediate_dense(hidden_states))
        return self.output_dense(expanded)
class Wav2Vec2BertSelfAttention(nn.Module):
    """Self-attention with an optional ``relative_key`` positional bias.

    Only ``position_embeddings_type`` values ``None`` and ``"relative_key"``
    are supported; any other value raises ``NotImplementedError``.
    """

    def __init__(self, config: Wav2Vec2BertConfig):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.num_heads = config.num_attention_heads
        self.head_size = width // config.num_attention_heads
        self.scale = self.head_size**-0.5

        self.position_embeddings_type = config.position_embeddings_type
        if self.position_embeddings_type not in (None, "relative_key"):
            raise NotImplementedError(
                f"position_embeddings_type={self.position_embeddings_type} not implemented"
            )

        self.linear_q = nn.Linear(width, width)
        self.linear_k = nn.Linear(width, width)
        self.linear_v = nn.Linear(width, width)
        self.linear_out = nn.Linear(width, width)

        if self.position_embeddings_type == "relative_key":
            left = config.left_max_position_embeddings
            right = config.right_max_position_embeddings
            self.left_max_position_embeddings = left
            self.right_max_position_embeddings = right
            # One embedding row per clipped distance in [-left, right].
            self.distance_embedding = nn.Embedding(left + right + 1, self.head_size)

    def _relative_key_scores(self, q: mx.array) -> mx.array:
        """Return the ``(B, H, T, T)`` positional score term for ``q`` of shape ``(B, H, T, Dh)``."""
        seq_len = q.shape[2]
        positions = mx.arange(seq_len)
        # offset[i, j] = j - i, clipped to the supported window.
        offset = positions[None, :] - positions[:, None]
        offset = mx.clip(
            offset,
            -self.left_max_position_embeddings,
            self.right_max_position_embeddings,
        )
        # Shift so the smallest distance maps to embedding row 0.
        table_index = offset + self.left_max_position_embeddings

        rel = self.distance_embedding(table_index).astype(q.dtype)  # (T, T, Dh)
        return mx.einsum("bhqd,qkd->bhqk", q, rel) * self.scale

    def __call__(
        self,
        hidden_states: mx.array,
        *,
        attention_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(B, T, C)`` activations.
            attention_mask: Optional ``(B, T)`` mask with 1 = keep, 0 = pad.

        Returns:
            ``(B, T, C)`` activations after the output projection.
        """
        batch, seq_len, width = hidden_states.shape

        def to_heads(projection: nn.Module) -> mx.array:
            # (B, T, C) -> (B, T, H, Dh) -> (B, H, T, Dh)
            projected = projection(hidden_states)
            return projected.reshape(
                batch, seq_len, self.num_heads, self.head_size
            ).transpose(0, 2, 1, 3)

        q = to_heads(self.linear_q)
        k = to_heads(self.linear_k)
        v = to_heads(self.linear_v)

        bias = None
        if attention_mask is not None:
            # Additive (B, 1, 1, T) mask: 0 for kept frames, -1e9 for padding.
            bias = (1.0 - attention_mask.astype(mx.float32))[:, None, None, :] * (-1e9)

        if self.position_embeddings_type == "relative_key":
            # The positional term is passed to the fused kernel through its mask argument.
            rel_scores = self._relative_key_scores(q)
            bias = rel_scores if bias is None else (bias + rel_scores)

        attended = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=bias
        )
        attended = attended.transpose(0, 2, 1, 3).reshape(batch, seq_len, width)
        return self.linear_out(attended)
+ self.ffn2_layer_norm = nn.LayerNorm(d, eps=config.layer_norm_eps) + self.ffn2 = Wav2Vec2BertFeedForward(config) + self.final_layer_norm = nn.LayerNorm(d, eps=config.layer_norm_eps) + + def __call__( + self, + hidden_states: mx.array, + *, + attention_mask: Optional[mx.array] = None, + conv_attention_mask: Optional[mx.array] = None, + ) -> mx.array: + # 1) FFN1 + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + + # 2) Self-attn + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states + residual + + # 3) Conv + residual = hidden_states + hidden_states = self.conv_module(hidden_states, attention_mask=conv_attention_mask) + hidden_states = hidden_states + residual + + # 4) FFN2 + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class Wav2Vec2BertEncoder(nn.Module): + def __init__(self, config: Wav2Vec2BertConfig): + super().__init__() + self.layers = [Wav2Vec2BertEncoderLayer(config) for _ in range(config.num_hidden_layers)] + + def __call__( + self, + hidden_states: mx.array, + *, + attention_mask: Optional[mx.array] = None, + output_hidden_states: bool = False, + ) -> Tuple[mx.array, Optional[List[mx.array]]]: + all_hidden_states: Optional[List[mx.array]] = [] if output_hidden_states else None + + conv_attention_mask = attention_mask + + if attention_mask is not None: + hidden_states = hidden_states * attention_mask[:, :, None] + + for layer in self.layers: + if all_hidden_states is not None: + all_hidden_states.append(hidden_states) + hidden_states = layer( + hidden_states, + 
class Wav2Vec2BertModel(nn.Module):
    """Feature projection followed by the Wav2Vec2-BERT encoder stack."""

    def __init__(self, config: Wav2Vec2BertConfig):
        super().__init__()
        self.config = config
        self.feature_projection = Wav2Vec2BertFeatureProjection(config)
        self.encoder = Wav2Vec2BertEncoder(config)

        # Not used in the forward pass. HF checkpoints carry this tensor
        # (SpecAugment during training), so it is kept to allow strict
        # weight loading.
        self.masked_spec_embed = mx.zeros((config.hidden_size,), dtype=mx.float32)

    def __call__(
        self,
        input_features: mx.array,
        *,
        attention_mask: Optional[mx.array] = None,
        output_hidden_states: bool = False,
    ) -> Tuple[mx.array, Optional[List[mx.array]]]:
        """Project the input features and run the encoder.

        Returns:
            Whatever the encoder returns: the final hidden states and, when
            ``output_hidden_states`` is true, the list of per-layer states
            (otherwise ``None``).
        """
        projected, _ = self.feature_projection(input_features)
        return self.encoder(
            projected,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

    def sanitize(self, weights: dict[str, mx.array]) -> dict[str, mx.array]:
        """Transpose conv weights from torch layout when needed.

        A 3-D weight is transposed only if its name matches a 3-D parameter of
        this model, the shapes differ, and swapping its last two axes would
        agree with the model parameter on the axes that are checked. All other
        entries pass through unchanged.
        """
        expected = dict(tree_flatten(self.parameters()))
        sanitized = {}
        for name, tensor in weights.items():
            target = expected.get(name)
            mismatched_3d = (
                target is not None
                and tensor.ndim == 3
                and target.ndim == 3
                and tensor.shape != target.shape
            )
            # Torch Conv1d: (O, I/groups, K) -> MLX: (O, K, I/groups)
            if (
                mismatched_3d
                and tensor.shape[0] == target.shape[0]
                and tensor.shape[2] == target.shape[1]
            ):
                tensor = tensor.transpose(0, 2, 1)
            sanitized[name] = tensor
        return sanitized
class W2VBertFeatureExtractor:
    """MLX feature extractor compatible with Wav2Vec2BertModel inputs.

    Computes filterbank features with ``compute_fbank_kaldi`` and then
    concatenates every ``cfg.stride`` consecutive frames into one vector
    (80 bins x stride 2 = 160 dims with the default config).
    """

    def __init__(self, cfg: Optional[W2VBertFeatureExtractorConfig] = None):
        if not cfg:
            cfg = W2VBertFeatureExtractorConfig()
        self.cfg = cfg

    def _stacked_fbank(self, wav: mx.array) -> mx.array:
        """Return stride-stacked filterbank features for one waveform."""
        cfg = self.cfg
        fb = compute_fbank_kaldi(
            wav,
            sample_rate=cfg.sampling_rate,
            win_len=cfg.win_len,
            win_inc=cfg.win_inc,
            num_mels=cfg.num_mel_bins,
            win_type="hamming",
            preemphasis=cfg.preemphasis,
            dither=cfg.dither,
            snip_edges=cfg.snip_edges,
            low_freq=cfg.low_freq,
            high_freq=cfg.high_freq,
        )  # (frames, num_mels)

        # Drop trailing frames that do not fill a whole group, then fold each
        # group of `stride` frames into a single row.
        stride = cfg.stride
        usable = (fb.shape[0] // stride) * stride
        fb = fb[:usable]
        return fb.reshape(usable // stride, stride * fb.shape[1])

    def __call__(
        self, audio: mx.array, *, lengths: Optional[mx.array] = None
    ) -> Tuple[mx.array, mx.array]:
        """Extract features.

        Args:
            audio: (T,) or (B, T) float waveform in [-1, 1]
            lengths: Optional (B,) lengths for each batch item. Defaults to
                the full width ``T`` for every item.

        Returns:
            input_features: (B, T', D) float32, padded with
                ``cfg.padding_value`` up to the longest item.
            attention_mask: (B, T') int32 with 1 for real frames, 0 for padded.

        Raises:
            ValueError: If ``audio`` is neither 1-D nor 2-D.
        """
        if audio.ndim == 1:
            audio = audio[None, :]
        if audio.ndim != 2:
            raise ValueError(f"audio must be shape (T,) or (B,T), got {audio.shape}")

        batch, num_samples = audio.shape
        if lengths is None:
            lengths = mx.array([num_samples] * batch)

        feats = [
            self._stacked_fbank(audio[row, : int(lengths[row].item())])
            for row in range(batch)
        ]
        frame_lens = [f.shape[0] for f in feats]
        max_frames = int(max(frame_lens) if frame_lens else 0)

        padded = [_pad_2d(f, max_frames, self.cfg.padding_value) for f in feats]
        input_features = mx.stack(padded, axis=0).astype(mx.float32)

        # Attention mask over frames: `count` ones followed by zeros.
        rows = []
        for count in frame_lens:
            count = int(count)
            if not 0 <= count <= max_frames:
                raise ValueError("Invalid frame length")
            rows.append(
                mx.concatenate(
                    [
                        mx.ones((count,), dtype=mx.int32),
                        mx.zeros((max_frames - count,), dtype=mx.int32),
                    ],
                    axis=0,
                )
            )
        mask = mx.stack(rows, axis=0) if rows else mx.zeros((batch, 0), dtype=mx.int32)

        return input_features, mask
class TestIndexTTS2Emotion(unittest.TestCase):
    """Tests for emotion-response parsing and emotion-vector normalization.

    Imports are done inside each test so that collecting this module does not
    require the package under test to be importable.
    """

    def test_parse_json_english(self):
        """A JSON response with English keys is parsed value-for-value."""
        from mlx_audio.tts.indextts2.emotion import (
            EMOTION_KEYS,
            normalize_emo_vector,
            parse_emotion_response,
        )

        resp = '{"happy": 0.5, "angry": 0.1, "sad": 0.0, "afraid": 0.0, "disgusted": 0, "melancholic": 0.0, "surprised": 0.2, "calm": 0.0}'
        emo = parse_emotion_response(resp)
        vec_dict, vec = normalize_emo_vector(emo, apply_bias=False)
        self.assertEqual(list(vec_dict.keys()), EMOTION_KEYS)
        self.assertEqual(len(vec), 8)
        self.assertAlmostEqual(vec_dict["happy"], 0.5)
        self.assertAlmostEqual(vec_dict["angry"], 0.1)
        self.assertAlmostEqual(vec_dict["surprised"], 0.2)

    def test_parse_json_chinese_keys(self):
        """Chinese emotion names are mapped onto their English keys."""
        from mlx_audio.tts.indextts2.emotion import (
            normalize_emo_vector,
            parse_emotion_response,
        )

        resp = '{"高兴": 1.0, "愤怒": 0.2, "悲伤": 0.3, "自然": 0.1}'
        emo = parse_emotion_response(resp)
        vec_dict, _ = normalize_emo_vector(emo, apply_bias=False)
        self.assertAlmostEqual(vec_dict["happy"], 1.0)
        self.assertAlmostEqual(vec_dict["angry"], 0.2)
        self.assertAlmostEqual(vec_dict["sad"], 0.3)
        self.assertAlmostEqual(vec_dict["calm"], 0.1)

    def test_regex_fallback(self):
        """Non-JSON ``key: value`` / ``key=value`` text is still parsed."""
        from mlx_audio.tts.indextts2.emotion import (
            normalize_emo_vector,
            parse_emotion_response,
        )

        resp = "happy: 0.7, angry=0.2 calm:0"
        emo = parse_emotion_response(resp)
        vec_dict, _ = normalize_emo_vector(emo, apply_bias=False)
        self.assertAlmostEqual(vec_dict["happy"], 0.7)
        self.assertAlmostEqual(vec_dict["angry"], 0.2)

    def test_default_calm_when_empty(self):
        """An empty emotion dict falls back to fully calm."""
        from mlx_audio.tts.indextts2.emotion import normalize_emo_vector

        vec_dict, vec = normalize_emo_vector({}, apply_bias=False)
        self.assertAlmostEqual(vec_dict["calm"], 1.0)
        # Compare approximately: the sum is a float and every other numeric
        # check in this class already tolerates rounding error.
        self.assertAlmostEqual(sum(vec), 1.0)

    def test_sum_clamp(self):
        """The vector sum is limited to ``max_sum`` without zeroing entries."""
        from mlx_audio.tts.indextts2.emotion import normalize_emo_vector

        emo = {"happy": 1.2, "angry": 1.2, "sad": 1.2}
        vec_dict, vec = normalize_emo_vector(emo, apply_bias=False, max_sum=0.8)
        self.assertLessEqual(sum(vec), 0.8 + 1e-6)
        self.assertGreater(vec_dict["happy"], 0.0)