diff --git a/mlx_audio/evals/README.md b/mlx_audio/evals/README.md
new file mode 100644
index 000000000..6dde5ed0e
--- /dev/null
+++ b/mlx_audio/evals/README.md
@@ -0,0 +1,267 @@
+# MLX-Audio Evaluations
+
+This module provides evaluation benchmarks for MLX-Audio models, enabling you to measure model performance against standardized datasets.
+
+## Available Benchmarks
+
+### InstructTTSEval
+
+[InstructTTSEval](https://arxiv.org/abs/2506.16381) is a benchmark for evaluating TTS systems' ability to follow complex natural-language style instructions. It measures how well models can synthesize speech that matches specified acoustic properties, styles, and personas.
+
+**Dataset**: [CaasiHUANG/InstructTTSEval](https://huggingface.co/datasets/CaasiHUANG/InstructTTSEval)
+
+## Installation
+
+Install mlx-audio with the evals dependencies:
+
+```bash
+pip install mlx-audio[evals]
+```
+
+Or install the dependencies manually:
+
+```bash
+pip install datasets google-generativeai openai
+```
+
+## Features
+
+- **Three Instruction Types**:
+  - **APS (Acoustic Property Specification)**: Low-level acoustic attribute descriptions (gender, pitch, speed, volume, age, emotion, etc.)
+  - **DSD (Detailed Style Description)**: High-level natural language style instructions
+  - **RP (Role-Play)**: Context-based scenario instructions for persona-driven synthesis
+
+- **Multilingual Support**:
+  - English (`en`): 1,000 samples
+  - Chinese (`zh`): 1,000 samples
+
+- **Flexible Evaluation**:
+  - LLM-as-judge scoring with Gemini or OpenAI
+  - Audio-only generation mode (skip scoring)
+  - Configurable sampling parameters
+
+- **Comprehensive Output**:
+  - Generated audio files (WAV format)
+  - Per-sample results (CSV)
+  - Summary statistics (JSON)
+
+## Usage
+
+### Basic Usage - Audio Generation Only
+
+Generate audio for all samples without LLM scoring:
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split en \
+    --evaluator skip \
+    --save-audio \
+    --output-dir results/instruct_tts_eval
+```
+
+### With Gemini Scoring
+
+Evaluate instruction-following with Gemini as the judge:
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split en \
+    --evaluator gemini \
+    --api-key $GOOGLE_API_KEY \
+    --save-audio \
+    --output-dir results/instruct_tts_eval
+```
+
+### With OpenAI Scoring
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split en \
+    --evaluator openai \
+    --api-key $OPENAI_API_KEY \
+    --save-audio \
+    --output-dir results/instruct_tts_eval
+```
+
+### Debugging with Limited Samples
+
+Test on a small subset before running full evaluation:
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split en \
+    --max-samples 10 \
+    --evaluator skip \
+    --save-audio \
+    --verbose
+```
+
+### Evaluate Specific Instruction Types
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split en \
+    --instruction-types APS DSD \
+    --evaluator gemini \
+    --api-key $GOOGLE_API_KEY
+```
+
+### Chinese Evaluation
+
+```bash
+python -m mlx_audio.evals.instruct_tts_eval \
+    --model mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16 \
+    --split zh \
+    --evaluator gemini \
+    --api-key $GOOGLE_API_KEY
+```
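+
+### Inspecting the Dataset
+
+To see what the instructions look like before launching a run, you can load a few samples with the same loader the benchmark script uses. This is a minimal sketch; the field names (`text`, `APS`, `DSD`, `RP`) are the ones this module reads from the dataset:
+
+```python
+from mlx_audio.evals.instruct_tts_eval import load_dataset
+
+# Load the first three English samples
+dataset = load_dataset(split="en", max_samples=3)
+
+for sample in dataset:
+    print("Text:", sample["text"])
+    # One instruction per type for each sample
+    for inst_type in ["APS", "DSD", "RP"]:
+        print(f"  {inst_type}: {sample[inst_type][:80]}...")
+```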
+
+## Command-Line Options
+
+| Option | Description | Default |
+|--------|-------------|---------|
+| `--model` | Path or HuggingFace repo ID of the TTS model | Required |
+| `--dataset` | HuggingFace dataset name | `CaasiHUANG/InstructTTSEval` |
+| `--split` | Dataset split (`en` or `zh`) | `en` |
+| `--instruction-types` | Instruction types to evaluate (`APS`, `DSD`, `RP`) | All three |
+| `--streaming` | Use streaming dataset loading | False |
+| `--max-samples` | Maximum samples to evaluate (for debugging) | None (all) |
+| `--output-dir` | Directory to save results | `results/instruct_tts_eval` |
+| `--max-tokens` | Maximum tokens to generate | `2048` |
+| `--temperature` | Sampling temperature | `0.7` |
+| `--voice` | Voice/speaker name (model-specific) | Auto-detected |
+| `--evaluator` | LLM evaluator (`gemini`, `openai`, `skip`) | `skip` |
+| `--api-key` | API key for evaluator service | None |
+| `--save-audio` | Save generated audio files | False |
+| `--verbose` | Print detailed output | False |
+| `--seed` | Random seed | `42` |
+
+## Evaluation Metrics
+
+### Scoring Methodology
+
+InstructTTSEval uses an **LLM-as-judge** approach with binary scoring:
+
+| Score | Criteria |
+|-------|----------|
+| **TRUE** | The sample's primary style attributes (gender, pitch, rate, emotion) align with the instruction without conflict |
+| **FALSE** | At least one key style attribute clearly conflicts with the instruction, or the overall style deviates from the prompt |
+
+The final score for each instruction type is the **percentage of TRUE responses** across all samples.
+
+### Reported Scores
+
+Reference scores for Qwen3-TTS models on InstructTTSEval:
+
+#### Qwen3-TTS-12Hz-1.7B-CustomVoice
+
+| Language | APS | DSD | RP |
+|----------|-----|-----|-----|
+| Chinese | 83.0 | 77.8 | 61.2 |
+| English | 77.3 | 77.1 | 63.7 |
+
+#### Qwen3-TTS-12Hz-1.7B-VoiceDesign
+
+| Language | APS | DSD | RP |
+|----------|-----|-----|-----|
+| Chinese | 85.2 | 81.1 | 65.1 |
+| English | 82.9 | 82.4 | 68.4 |
+
+### Human-LLM Agreement
+
+The benchmark authors validated Gemini's evaluation against human annotators:
+
+| Instruction Type | Agreement Rate |
+|------------------|----------------|
+| APS | 87% |
+| DSD | 79% |
+| RP | 71% |
+| **Average** | **79%** |
+
+## Output Files
+
+After running an evaluation, you'll find:
+
+```
+results/instruct_tts_eval/
+├── {model_name}_InstructTTSEval_{split}.csv    # Per-sample results
+├── {model_name}_InstructTTSEval_{split}.json   # Summary statistics
+└── audio/                                      # Generated audio (if --save-audio)
+    ├── en_0_APS.wav
+    ├── en_0_DSD.wav
+    ├── en_0_RP.wav
+    └── ...
+```
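+
+The JSON summary (format shown below) is the quickest way to compare runs. A minimal sketch for reading it back; the file name here is hypothetical and follows the pattern above, and with `--evaluator skip` the `accuracy` fields are null:
+
+```python
+import json
+from pathlib import Path
+
+# Hypothetical file name -- substitute your own model name and split
+summary_path = Path("results/instruct_tts_eval/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16_InstructTTSEval_en.json")
+
+with open(summary_path) as f:
+    summary = json.load(f)
+
+for inst_type, stats in summary["scores"].items():
+    print(f"{inst_type}: {stats['correct']}/{stats['total']} ({stats['accuracy']}%)")
+print("Average:", summary["average_score"])
+```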
+
+### CSV Format
+
+| Column | Description |
+|--------|-------------|
+| `id` | Sample identifier (e.g., `en_0`) |
+| `instruction_type` | Type of instruction (`APS`, `DSD`, `RP`) |
+| `text` | Text that was synthesized |
+| `instruction` | Style instruction given to the model |
+| `generated` | Whether audio was successfully generated |
+| `score` | Evaluation result (`True`, `False`, or empty if skipped) |
+
+### JSON Summary
+
+```json
+{
+  "model": "mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16",
+  "dataset": "CaasiHUANG/InstructTTSEval",
+  "split": "en",
+  "instruction_types": ["APS", "DSD", "RP"],
+  "evaluator": "gemini",
+  "total_samples": 1000,
+  "scores": {
+    "APS": {"correct": 773, "total": 1000, "accuracy": 77.3},
+    "DSD": {"correct": 771, "total": 1000, "accuracy": 77.1},
+    "RP": {"correct": 637, "total": 1000, "accuracy": 63.7}
+  },
+  "average_score": 72.7
+}
+```
+
+## Python API
+
+You can also use the evaluation functions programmatically:
+
+```python
+from mlx_audio.evals.instruct_tts_eval import (
+    load_dataset,
+    run_inference,
+    evaluate_with_llm,
+    save_audio,
+)
+from mlx_audio.tts.utils import load as load_tts_model
+
+# Load model
+model = load_tts_model("mlx-community/Qwen3-TTS-12Hz-1.7B-CustomVoice-bf16")
+
+# Load dataset
+dataset = load_dataset(split="en", max_samples=10)
+
+# Generate and evaluate
+for sample in dataset:
+    audio = run_inference(
+        model=model,
+        text=sample["text"],
+        instruction=sample["APS"],
+        voice="vivian",
+        lang_code="en",
+    )
+
+    if audio is not None:
+        save_audio(audio, "output.wav", sample_rate=model.sample_rate)
+```
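+
+To score a generated clip as well, pass the saved file to `evaluate_with_llm` (imported in the example above). A minimal sketch that reuses the `dataset` loaded above and the `output.wav` written in the loop, and assumes a `GOOGLE_API_KEY` is set in your environment; the judge returns a single boolean per clip:
+
+```python
+import os
+
+# Judge the clip against the same instruction that was used for synthesis
+is_correct = evaluate_with_llm(
+    audio_path="output.wav",
+    instruction=dataset[0]["APS"],
+    instruction_type="APS",
+    evaluator="gemini",
+    api_key=os.environ.get("GOOGLE_API_KEY"),
+)
+print("Follows instruction:", is_correct)
+```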
+
+## References
+
+- **Paper**: [InstructTTSEval: Benchmarking Complex Natural-Language Instruction Following in Text-to-Speech Systems](https://arxiv.org/abs/2506.16381)
+- **Dataset**: [CaasiHUANG/InstructTTSEval](https://huggingface.co/datasets/CaasiHUANG/InstructTTSEval)
+- **GitHub**: [InstructTTSEval](https://github.com/KexinHUANG19/InstructTTSEval)
diff --git a/mlx_audio/evals/__init__.py b/mlx_audio/evals/__init__.py
new file mode 100644
index 000000000..caf473f24
--- /dev/null
+++ b/mlx_audio/evals/__init__.py
@@ -0,0 +1 @@
+__all__ = ["instruct_tts_eval"]
diff --git a/mlx_audio/evals/instruct_tts_eval.py b/mlx_audio/evals/instruct_tts_eval.py
new file mode 100644
index 000000000..5c4a1fed6
--- /dev/null
+++ b/mlx_audio/evals/instruct_tts_eval.py
@@ -0,0 +1,605 @@
+"""
+InstructTTSEval: Evaluation of instruction-following capabilities in TTS systems.
+
+This module implements the InstructTTSEval benchmark for evaluating TTS models on their
+ability to follow complex natural-language style instructions.
+
+The benchmark has three task types:
+- APS (Acoustic Property Specification): Low-level acoustic attribute descriptions
+- DSD (Detailed Style Description): High-level style instructions
+- RP (Role-Play): Context-based scenario instructions
+
+Reference: https://arxiv.org/abs/2506.16381
+Dataset: https://huggingface.co/datasets/CaasiHUANG/InstructTTSEval
+"""
+
+import argparse
+import csv
+import json
+import logging
+import os
+import random
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+import mlx.core as mx
+import numpy as np
+from tqdm import tqdm
+
+from mlx_audio.audio_io import write as audio_write
+from mlx_audio.tts.utils import load as load_tts_model
+
+from .utils import inference
+
+# Instruction types in InstructTTSEval
+INSTRUCTION_TYPES = ["APS", "DSD", "RP"]
+
+
+def load_dataset(
+    dataset_name: str = "CaasiHUANG/InstructTTSEval",
+    split: str = "en",
+    streaming: bool = False,
+    max_samples: Optional[int] = None,
+):
+    """
+    Load the InstructTTSEval dataset from HuggingFace.
+
+    Args:
+        dataset_name: HuggingFace dataset name.
+        split: Dataset split ('en' for English, 'zh' for Chinese).
+        streaming: Whether to use streaming mode.
+        max_samples: Maximum number of samples to load (for debugging).
+
+    Returns:
+        Dataset object.
+    """
+    from datasets import load_dataset as hf_load_dataset
+
+    dataset = hf_load_dataset(dataset_name, split=split, streaming=streaming)
+
+    # Remove audio column to avoid decoding issues because of torchcodec not being installed
+    if "reference_audio" in dataset.column_names:
+        dataset = dataset.remove_columns(["reference_audio"])
+
+    if max_samples and not streaming:
+        dataset = dataset.select(range(min(max_samples, len(dataset))))
+    elif max_samples and streaming:
+        dataset = dataset.take(max_samples)
+
+    return dataset
+
+
+def save_audio(audio: mx.array, path: str, sample_rate: int = 24000) -> None:
+    """Save audio to file."""
+    audio_write(path, np.array(audio), sample_rate, format="wav")
+
+
+def get_voice_for_model(model, model_type: str, lang_code: str = "en") -> Optional[str]:
+    """
+    Get an appropriate voice/speaker for the model.
+
+    Args:
+        model: Loaded TTS model.
+        model_type: Type of model (e.g., 'CustomVoice', 'VoiceDesign').
+        lang_code: Language code ('en' or 'zh').
+
+    Returns:
+        Voice name or None.
+    """
+    if hasattr(model, "config"):
+        tts_model_type = getattr(model.config, "tts_model_type", None)
+        if tts_model_type == "custom_voice":
+            # CustomVoice models have predefined speakers
+            # Try to get available speakers from model
+            if hasattr(model, "available_speakers"):
+                speakers = model.available_speakers
+                if speakers:
+                    # Prefer English-sounding names for English, Chinese for Chinese
+                    if lang_code == "en":
+                        for name in ["vivian", "ryan", "aiden", "eric", "dylan"]:
+                            if name in speakers:
+                                return name
+                    else:  # Chinese
+                        for name in ["uncle_fu", "serena", "ono_anna", "sohee"]:
+                            if name in speakers:
+                                return name
+                    return speakers[0]  # Fall back to first available
+            # Default speakers based on model
+            return "vivian"  # Common default for Qwen3-TTS CustomVoice
+        elif tts_model_type == "voice_design":
+            # VoiceDesign models don't need a voice parameter
+            return None
+    return None
+
+
+def run_inference(
+    model,
+    text: str,
+    instruction: str,
+    voice: Optional[str] = None,
+    lang_code: str = "auto",
+    max_tokens: int = 2048,
+    temperature: float = 0.7,
+    verbose: bool = False,
+) -> Optional[mx.array]:
+    """
+    Run TTS inference with the given instruction.
+
+    Args:
+        model: Loaded TTS model.
+ text: Text to synthesize. + instruction: Style instruction (APS, DSD, or RP content). + voice: Voice/speaker name. + lang_code: Language code. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + verbose: Whether to print verbose output. + + Returns: + Generated audio array or None if generation failed. + """ + try: + results = list( + inference( + model=model, + text=text, + voice=voice, + instruct=instruction, + lang_code=lang_code, + max_tokens=max_tokens, + temperature=temperature, + verbose=verbose, + ) + ) + + if results: + # Concatenate all audio segments + audio_segments = [r.audio for r in results] + return mx.concatenate(audio_segments, axis=0) + return None + except Exception as e: + logging.error(f"Inference error: {e}") + return None + + +def evaluate_with_llm( + audio_path: str, + instruction: str, + instruction_type: str, + evaluator: str = "gemini", + api_key: Optional[str] = None, +) -> bool: + """ + Evaluate generated audio against instruction using LLM-as-judge. + + The evaluation uses a binary rubric: + - True: Primary style attributes align with the prompt + - False: At least one key style attribute conflicts with the prompt + + Args: + audio_path: Path to generated audio file. + instruction: The instruction that was given to the TTS model. + instruction_type: Type of instruction (APS, DSD, RP). + evaluator: Evaluator to use ('gemini', 'openai', 'local'). + api_key: API key for the evaluator service. + + Returns: + Boolean indicating whether the audio follows the instruction. + """ + if evaluator == "skip": + # Skip LLM evaluation, just return True (for audio-only generation runs) + return True + + # Evaluation prompt based on InstructTTSEval paper + eval_prompt = f"""You are evaluating a text-to-speech (TTS) system's ability to follow style instructions. + +The TTS system was given this instruction: +--- +{instruction} +--- + +Listen to the generated audio and determine if it follows the instruction. + +Scoring rubric: +- TRUE: The sample's primary style attributes (e.g., gender, pitch, rate, emotion) align with the instruction, without conflict. +- FALSE: At least one key style attribute clearly conflicts with the instruction, or the overall style deviates from the instruction. 
+ +Respond with only TRUE or FALSE.""" + + if evaluator == "gemini": + return _evaluate_with_gemini(audio_path, eval_prompt, api_key) + elif evaluator == "openai": + return _evaluate_with_openai(audio_path, eval_prompt, api_key) + else: + logging.warning(f"Unknown evaluator: {evaluator}, returning True") + return True + + +def _evaluate_with_gemini( + audio_path: str, prompt: str, api_key: Optional[str] = None +) -> bool: + """Evaluate using Google's Gemini API.""" + try: + import google.generativeai as genai + + if api_key: + genai.configure(api_key=api_key) + else: + # Try to get from environment + api_key = os.environ.get("GOOGLE_API_KEY") + if api_key: + genai.configure(api_key=api_key) + else: + logging.warning("No Gemini API key provided, skipping evaluation") + return True + + model = genai.GenerativeModel("gemini-2.0-flash") + + # Upload audio file + audio_file = genai.upload_file(audio_path) + + response = model.generate_content([prompt, audio_file]) + result = response.text.strip().upper() + + return result == "TRUE" + except Exception as e: + logging.error(f"Gemini evaluation error: {e}") + return True # Default to True on error + + +def _evaluate_with_openai( + audio_path: str, prompt: str, api_key: Optional[str] = None +) -> bool: + """Evaluate using OpenAI's API with audio capabilities.""" + try: + import base64 + + from openai import OpenAI + + client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY")) + + # Read and encode audio + with open(audio_path, "rb") as f: + audio_data = base64.b64encode(f.read()).decode("utf-8") + + response = client.chat.completions.create( + model="gpt-4o-audio-preview", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "input_audio", + "input_audio": {"data": audio_data, "format": "wav"}, + }, + ], + } + ], + ) + + result = response.choices[0].message.content.strip().upper() + return result == "TRUE" + except Exception as e: + logging.error(f"OpenAI evaluation error: {e}") + return True # Default to True on error + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Evaluate TTS models on InstructTTSEval benchmark" + ) + parser.add_argument( + "--model", + type=str, + required=True, + help="Path or HuggingFace repo ID of the TTS model", + ) + parser.add_argument( + "--dataset", + type=str, + default="CaasiHUANG/InstructTTSEval", + help="HuggingFace dataset name", + ) + parser.add_argument( + "--split", + type=str, + default="en", + choices=["en", "zh"], + help="Dataset split to evaluate on (en=English, zh=Chinese)", + ) + parser.add_argument( + "--instruction-types", + type=str, + nargs="+", + default=["APS", "DSD", "RP"], + choices=["APS", "DSD", "RP"], + help="Instruction types to evaluate", + ) + parser.add_argument( + "--streaming", + action="store_true", + help="Use streaming dataset loading", + ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Maximum number of samples to evaluate (for debugging)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="results/instruct_tts_eval", + help="Directory to save results", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=2048, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Temperature for generation", + ) + parser.add_argument( + "--voice", + type=str, + default=None, + help="Voice/speaker name (model-specific)", + ) + parser.add_argument( + "--evaluator", + type=str, + 
default="skip", + choices=["gemini", "openai", "skip"], + help="LLM evaluator to use for scoring (skip=audio generation only)", + ) + parser.add_argument( + "--api-key", + type=str, + default=None, + help="API key for the evaluator service", + ) + parser.add_argument( + "--save-audio", + action="store_true", + help="Save generated audio files", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Print detailed output for debugging", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + random.seed(args.seed) + + # Setup logging + logging.basicConfig( + level=logging.INFO if args.verbose else logging.WARNING, + format="%(asctime)s - %(levelname)s - %(message)s", + ) + + # Create output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + if args.save_audio: + audio_dir = output_dir / "audio" + audio_dir.mkdir(parents=True, exist_ok=True) + + # Load model + # Determine language code from split + lang_code = "en" if args.split == "en" else "zh" + + logging.info(f"Loading model from {args.model}") + print(f"Loading model: {args.model}") + model = load_tts_model(args.model) + + # Determine voice if not specified + voice = args.voice + if voice is None: + voice = get_voice_for_model(model, args.model, lang_code) + if voice: + print(f"Using voice: {voice}") + + # Load dataset + logging.info(f"Loading dataset {args.dataset}, split {args.split}") + print(f"Loading dataset: {args.dataset} (split={args.split})") + dataset = load_dataset( + args.dataset, + split=args.split, + streaming=args.streaming, + max_samples=args.max_samples, + ) + + # Initialize results tracking + results = [] + scores = {inst_type: {"correct": 0, "total": 0} for inst_type in args.instruction_types} + + # Evaluate each sample + model_name = args.model.split("/")[-1] + + # Get total count for progress bar + try: + total = len(dataset) + except TypeError: + total = args.max_samples if args.max_samples else None + + for idx, sample in enumerate(tqdm(dataset, desc="Evaluating", total=total)): + sample_id = sample.get("id", f"{args.split}_{idx}") + text = sample["text"] + + # Process each instruction type + for inst_type in args.instruction_types: + instruction = sample[inst_type] + + # Run inference + audio = run_inference( + model=model, + text=text, + instruction=instruction, + voice=voice, + lang_code=lang_code, + max_tokens=args.max_tokens, + temperature=args.temperature, + verbose=args.verbose, + ) + + if audio is None: + logging.warning(f"Failed to generate audio for sample {sample_id} ({inst_type})") + result = { + "id": sample_id, + "instruction_type": inst_type, + "text": text, + "instruction": instruction[:200] + "..." 
if len(instruction) > 200 else instruction, + "generated": False, + "score": False, + } + results.append(result) + scores[inst_type]["total"] += 1 + continue + + # Save audio if requested + audio_path = None + if args.save_audio: + audio_path = str(audio_dir / f"{sample_id}_{inst_type}.wav") + save_audio(audio, audio_path, sample_rate=model.sample_rate) + + # Evaluate with LLM if not skipping + if args.evaluator != "skip": + # Need to save audio temporarily for evaluation + if audio_path is None: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f: + audio_path = f.name + save_audio(audio, audio_path, sample_rate=model.sample_rate) + + is_correct = evaluate_with_llm( + audio_path=audio_path, + instruction=instruction, + instruction_type=inst_type, + evaluator=args.evaluator, + api_key=args.api_key, + ) + + # Clean up temp file + if not args.save_audio and audio_path: + os.unlink(audio_path) + else: + is_correct = None # No evaluation + + # Record result + result = { + "id": sample_id, + "instruction_type": inst_type, + "text": text, + "instruction": instruction[:200] + "..." if len(instruction) > 200 else instruction, + "generated": True, + "score": is_correct, + } + results.append(result) + + scores[inst_type]["total"] += 1 + if is_correct: + scores[inst_type]["correct"] += 1 + + # Progress update every 10 samples + if (idx + 1) % 10 == 0: + logging.info(f"Processed {idx + 1} samples") + + # Calculate final scores + final_scores = {} + for inst_type in args.instruction_types: + total = scores[inst_type]["total"] + correct = scores[inst_type]["correct"] + if total > 0 and args.evaluator != "skip": + final_scores[inst_type] = (correct / total) * 100 + else: + final_scores[inst_type] = None + + # Calculate average score + valid_scores = [s for s in final_scores.values() if s is not None] + avg_score = sum(valid_scores) / len(valid_scores) if valid_scores else None + + # Save results to CSV + results_file = output_dir / f"{model_name}_InstructTTSEval_{args.split}.csv" + fieldnames = ["id", "instruction_type", "text", "instruction", "generated", "score"] + + with open(results_file, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(results) + + # Save summary + summary = { + "model": args.model, + "dataset": args.dataset, + "split": args.split, + "instruction_types": args.instruction_types, + "evaluator": args.evaluator, + "total_samples": len(set(r["id"] for r in results)), + "scores": { + inst_type: { + "correct": scores[inst_type]["correct"], + "total": scores[inst_type]["total"], + "accuracy": final_scores[inst_type], + } + for inst_type in args.instruction_types + }, + "average_score": avg_score, + } + + summary_file = output_dir / f"{model_name}_InstructTTSEval_{args.split}.json" + with open(summary_file, "w") as f: + json.dump(summary, f, indent=2) + + # Print results + print(f"\n{'='*80}") + print("InstructTTSEval Results") + print(f"{'='*80}") + print(f"Model: {args.model}") + print(f"Split: {args.split}") + print(f"Evaluator: {args.evaluator}") + print(f"Total Samples: {summary['total_samples']}") + + if args.evaluator != "skip": + print(f"\n{'-'*80}") + print("Scores by Instruction Type:") + print(f"{'-'*80}") + for inst_type in args.instruction_types: + score_info = summary["scores"][inst_type] + print( + f" {inst_type}: {score_info['correct']}/{score_info['total']} " + f"({score_info['accuracy']:.2f}%)" + ) + + if avg_score is not None: + print(f"\nAverage Score: {avg_score:.2f}%") 
+ else: + print("\nScoring skipped (--evaluator=skip)") + print(f"Generated audio for {sum(s['total'] for s in scores.values())} instruction-text pairs") + + print(f"{'='*80}") + print(f"\nResults saved to {results_file}") + print(f"Summary saved to {summary_file}") + if args.save_audio: + print(f"Audio files saved to {audio_dir}") + + +if __name__ == "__main__": + main() diff --git a/mlx_audio/evals/utils.py b/mlx_audio/evals/utils.py new file mode 100644 index 000000000..68eb77fc1 --- /dev/null +++ b/mlx_audio/evals/utils.py @@ -0,0 +1,90 @@ +"""Common utilities for MLX-Audio evaluations.""" + +from typing import Generator, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from mlx_audio.tts.models.base import GenerationResult +from mlx_audio.tts.utils import load as load_tts_model +from mlx_audio.utils import load_audio + + +def inference( + model: nn.Module, + text: str, + voice: Optional[str] = None, + instruct: Optional[str] = None, + ref_audio: Optional[mx.array] = None, + ref_text: Optional[str] = None, + lang_code: str = "auto", + max_tokens: int = 2048, + temperature: float = 0.7, + speed: float = 1.0, + verbose: bool = False, +) -> Generator[GenerationResult, None, None]: + """ + Run TTS inference on a single text input. + + Args: + model: Loaded TTS model. + text: Text to synthesize. + voice: Voice/speaker name for the model. + instruct: Instruction for style control (CustomVoice/VoiceDesign models). + ref_audio: Reference audio for voice cloning. + ref_text: Transcript of reference audio. + lang_code: Language code (e.g., 'en', 'zh', 'auto'). + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + speed: Speech speed multiplier. + verbose: Whether to print verbose output. + + Returns: + Generator yielding GenerationResult objects. + """ + gen_kwargs = dict( + text=text, + voice=voice, + instruct=instruct, + ref_audio=ref_audio, + ref_text=ref_text, + lang_code=lang_code, + max_tokens=max_tokens, + temperature=temperature, + speed=speed, + verbose=verbose, + stream=False, + ) + + # Filter out None values + gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None} + + return model.generate(**gen_kwargs) + + +def get_audio_from_result(result: GenerationResult) -> mx.array: + """Extract audio array from a GenerationResult.""" + return result.audio + + +def load_reference_audio( + audio_path: str, + sample_rate: int = 24000, + volume_normalize: bool = False, +) -> mx.array: + """ + Load a reference audio file. + + Args: + audio_path: Path to the audio file. + sample_rate: Target sample rate. + volume_normalize: Whether to normalize volume. + + Returns: + Audio array. + """ + return load_audio( + audio_path, + sample_rate=sample_rate, + volume_normalize=volume_normalize, + ) diff --git a/pyproject.toml b/pyproject.toml index fb6b9e460..8e3455904 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,11 +99,19 @@ dev = [ "pre-commit>=3.7.0", ] +# Evaluation dependencies +evals = [ + "datasets>=4.0.0", + "google-generativeai>=0.8.0", + "openai>=1.0.0", +] + [project.scripts] "mlx_audio.convert" = "mlx_audio.convert:main" "mlx_audio.stt.generate" = "mlx_audio.stt.generate:main" "mlx_audio.tts.generate" = "mlx_audio.tts.generate:main" "mlx_audio.server" = "mlx_audio.server:main" +"mlx_audio.evals.instruct_tts_eval" = "mlx_audio.evals.instruct_tts_eval:main" [project.urls] Homepage = "https://github.com/Blaizzy/mlx-audio"