diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py index 8f2e41a..f6d40d6 100644 --- a/src/voxcpm/cli.py +++ b/src/voxcpm/cli.py @@ -11,11 +11,6 @@ import sys from pathlib import Path -import soundfile as sf - -from voxcpm.core import VoxCPM - - DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2" # ----------------------------- @@ -173,7 +168,9 @@ def validate_batch_args(args, parser): # ----------------------------- -def load_model(args) -> VoxCPM: +def load_model(args): + from voxcpm.core import VoxCPM + print("Loading VoxCPM model...", file=sys.stderr) zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( @@ -266,6 +263,8 @@ def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None and (args.prompt_audio is not None or args.reference_audio is not None), ) + import soundfile as sf + sf.write(str(output_path), audio_array, model.tts_model.sample_rate) duration = len(audio_array) / model.tts_model.sample_rate @@ -288,7 +287,27 @@ def cmd_clone(args, parser): ) +def cmd_validate(args, parser): + from voxcpm.training.validate import ( + print_validation_report, + validate_manifest, + ) + + manifest = str(require_file_exists(args.manifest, parser, "manifest file")) + result = validate_manifest( + manifest_path=manifest, + sample_rate=args.sample_rate, + max_samples=args.max_samples, + verbose=args.verbose, + ) + print_validation_report(result, manifest) + if not result.is_valid: + sys.exit(1) + + def cmd_batch(args, parser): + import soundfile as sf + input_file = require_file_exists(args.input, parser, "input file") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -532,6 +551,30 @@ def _build_parser(): _add_model_args(batch_parser) _add_lora_args(batch_parser) + # Validate subcommand + validate_parser = subparsers.add_parser( + "validate", + help="Validate a training data manifest (JSONL) before fine-tuning", + ) + validate_parser.add_argument( + "--manifest", "-m", required=True, help="Path to JSONL training manifest" + ) + validate_parser.add_argument( + "--sample-rate", + type=int, + default=16_000, + help="Expected audio sample rate in Hz (default: 16000)", + ) + validate_parser.add_argument( + "--max-samples", + type=int, + default=0, + help="Maximum number of samples to validate (0 = all, default: 0)", + ) + validate_parser.add_argument( + "--verbose", "-v", action="store_true", help="Print per-sample progress" + ) + # Legacy root arguments parser.add_argument("--input", "-i", help="Input text file (batch mode only)") parser.add_argument( @@ -584,6 +627,9 @@ def main(): parser = _build_parser() args = parser.parse_args() + if args.command == "validate": + return cmd_validate(args, parser) + validate_ranges(args, parser) if args.command == "design": diff --git a/src/voxcpm/training/__init__.py b/src/voxcpm/training/__init__.py index 82e9ed6..04800d8 100644 --- a/src/voxcpm/training/__init__.py +++ b/src/voxcpm/training/__init__.py @@ -15,6 +15,7 @@ BatchProcessor, ) from .state import TrainingState +from .validate import validate_manifest, ValidationResult __all__ = [ "Accelerator", @@ -24,4 +25,6 @@ "TrainingState", "load_audio_text_datasets", "build_dataloader", + "validate_manifest", + "ValidationResult", ] diff --git a/src/voxcpm/training/validate.py b/src/voxcpm/training/validate.py new file mode 100644 index 0000000..bad32a0 --- /dev/null +++ b/src/voxcpm/training/validate.py @@ -0,0 +1,309 @@ +""" +Pre-flight validation for VoxCPM training data manifests. + +Validates JSONL manifest files before starting expensive fine-tuning jobs, +catching format issues, missing files, and data quality problems early. +""" + +import json +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + + +@dataclass +class ValidationResult: + """Structured result of a manifest validation run.""" + + total_samples: int = 0 + valid_samples: int = 0 + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + audio_durations: List[float] = field(default_factory=list) + text_lengths: List[int] = field(default_factory=list) + has_ref_audio: int = 0 + + @property + def is_valid(self) -> bool: + return len(self.errors) == 0 and self.valid_samples > 0 + + +def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]: + """Check if an audio file exists, is readable, and matches expected sample rate. + + Returns an error message, or None if the file is valid. + """ + if not os.path.isfile(audio_path): + return f"Audio file not found: {audio_path}" + try: + import soundfile as sf + + info = sf.info(audio_path) + if info.frames == 0: + return f"Audio file is empty: {audio_path}" + if info.samplerate != sample_rate: + return ( + f"Sample rate mismatch in {audio_path}: " + f"expected {sample_rate} Hz, got {info.samplerate} Hz" + ) + return None + except ImportError: + # soundfile not available; just check existence + return None + except Exception as e: + return f"Cannot read audio file {audio_path}: {e}" + + +def _get_audio_duration(audio_path: str) -> Optional[float]: + """Get audio duration in seconds. Returns None if unavailable.""" + try: + import soundfile as sf + + info = sf.info(audio_path) + return info.duration + except Exception: + return None + + +def validate_manifest( + manifest_path: str, + sample_rate: int = 16_000, + max_samples: int = 0, + verbose: bool = False, +) -> ValidationResult: + """Validate a JSONL training manifest file. + + Checks: + 1. File exists and is readable + 2. Each line is valid JSON + 3. Required columns present (text, audio) + 4. Audio files exist and are readable + 5. Text content is non-empty + 6. Collects duration and text length statistics + 7. Validates optional ref_audio column + + Args: + manifest_path: Path to the JSONL manifest file. + sample_rate: Expected audio sample rate (for informational purposes). + max_samples: Maximum number of samples to validate (0 = all). + verbose: Print per-sample progress. + + Returns: + ValidationResult with errors, warnings, and statistics. + """ + result = ValidationResult() + path = Path(manifest_path) + + if not path.exists(): + result.errors.append(f"Manifest file not found: {manifest_path}") + return result + + if not path.is_file(): + result.errors.append(f"Manifest path is not a file: {manifest_path}") + return result + + manifest_dir = path.parent + + try: + with open(path, "r", encoding="utf-8") as f: + lines = f.readlines() + except Exception as e: + result.errors.append(f"Cannot read manifest file: {e}") + return result + + if not lines: + result.errors.append("Manifest file is empty") + return result + + samples_to_check = len(lines) + if max_samples > 0: + samples_to_check = min(samples_to_check, max_samples) + + missing_audio_count = 0 + empty_text_count = 0 + + for i, line in enumerate(lines[:samples_to_check]): + line = line.strip() + if not line: + continue + + result.total_samples += 1 + + # Check JSON validity + try: + entry = json.loads(line) + except json.JSONDecodeError as e: + result.errors.append(f"Line {i + 1}: Invalid JSON — {e}") + continue + + if not isinstance(entry, dict): + result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}") + continue + + # Check required columns + has_error = False + + if "text" not in entry: + result.errors.append(f"Line {i + 1}: Missing required column 'text'") + has_error = True + + if "audio" not in entry: + result.errors.append(f"Line {i + 1}: Missing required column 'audio'") + has_error = True + + if has_error: + continue + + # Validate text + text = entry["text"] + if not isinstance(text, str) or not text.strip(): + empty_text_count += 1 + if empty_text_count <= 5: + result.warnings.append(f"Line {i + 1}: Empty or non-string text") + else: + result.text_lengths.append(len(text)) + + # Validate audio path + audio_path = entry["audio"] + if isinstance(audio_path, dict): + # HuggingFace Audio format with {"path": ..., "array": ...} + audio_path = audio_path.get("path", "") + + if isinstance(audio_path, str) and audio_path: + # Resolve relative paths against manifest directory + if not os.path.isabs(audio_path): + audio_path = str(manifest_dir / audio_path) + + audio_error = _check_audio_file(audio_path, sample_rate) + if audio_error: + missing_audio_count += 1 + if missing_audio_count <= 5: + result.errors.append(f"Line {i + 1}: {audio_error}") + has_error = True + else: + duration = _get_audio_duration(audio_path) + if duration is not None: + result.audio_durations.append(duration) + if duration < 0.3: + result.warnings.append( + f"Line {i + 1}: Very short audio ({duration:.2f}s)" + ) + elif duration > 30.0: + result.warnings.append( + f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM" + ) + else: + result.errors.append(f"Line {i + 1}: Invalid audio path") + has_error = True + + # Validate optional ref_audio + if "ref_audio" in entry: + ref_path = entry["ref_audio"] + if isinstance(ref_path, dict): + ref_path = ref_path.get("path", "") + if isinstance(ref_path, str) and ref_path: + if not os.path.isabs(ref_path): + ref_path = str(manifest_dir / ref_path) + if os.path.isfile(ref_path): + result.has_ref_audio += 1 + else: + result.warnings.append( + f"Line {i + 1}: ref_audio file not found: {ref_path}" + ) + + if not has_error: + result.valid_samples += 1 + + if verbose and (i + 1) % 100 == 0: + print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr) + + # Summarize truncated errors + if missing_audio_count > 5: + result.errors.append( + f"... and {missing_audio_count - 5} more missing audio files " + f"({missing_audio_count} total)" + ) + if empty_text_count > 5: + result.warnings.append( + f"... and {empty_text_count - 5} more empty text entries " + f"({empty_text_count} total)" + ) + + return result + + +def print_validation_report(result: ValidationResult, manifest_path: str) -> None: + """Print a human-readable validation report to stderr.""" + print(f"\n{'=' * 60}", file=sys.stderr) + print(f" VoxCPM Training Data Validation Report", file=sys.stderr) + print(f"{'=' * 60}", file=sys.stderr) + print(f" Manifest : {manifest_path}", file=sys.stderr) + print(f" Samples : {result.valid_samples}/{result.total_samples} valid", file=sys.stderr) + + if result.has_ref_audio > 0: + print( + f" Ref Audio: {result.has_ref_audio} samples with reference audio", + file=sys.stderr, + ) + + # Audio duration statistics + if result.audio_durations: + durations = sorted(result.audio_durations) + total_hrs = sum(durations) / 3600 + print(f"\n Audio Duration Statistics:", file=sys.stderr) + print(f" Total : {total_hrs:.2f} hours", file=sys.stderr) + print( + f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s", + file=sys.stderr, + ) + print( + f" Mean : {sum(durations) / len(durations):.2f}s", + file=sys.stderr, + ) + median_idx = len(durations) // 2 + print(f" Median : {durations[median_idx]:.2f}s", file=sys.stderr) + + # Text length statistics + if result.text_lengths: + lengths = sorted(result.text_lengths) + print(f"\n Text Length Statistics (characters):", file=sys.stderr) + print( + f" Range : {lengths[0]} — {lengths[-1]}", + file=sys.stderr, + ) + print( + f" Mean : {sum(lengths) / len(lengths):.0f}", + file=sys.stderr, + ) + + # Errors + if result.errors: + print(f"\n ERRORS ({len(result.errors)}):", file=sys.stderr) + for err in result.errors[:20]: + print(f" x {err}", file=sys.stderr) + if len(result.errors) > 20: + print( + f" ... ({len(result.errors) - 20} more errors omitted)", + file=sys.stderr, + ) + + # Warnings + if result.warnings: + print(f"\n WARNINGS ({len(result.warnings)}):", file=sys.stderr) + for warn in result.warnings[:10]: + print(f" ! {warn}", file=sys.stderr) + if len(result.warnings) > 10: + print( + f" ... ({len(result.warnings) - 10} more warnings omitted)", + file=sys.stderr, + ) + + # Summary + print(f"\n{'=' * 60}", file=sys.stderr) + if result.is_valid: + print(" PASSED: Manifest is valid for training.", file=sys.stderr) + else: + print(" FAILED: Fix errors above before starting training.", file=sys.stderr) + print(f"{'=' * 60}\n", file=sys.stderr) diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..06e1b8a --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,252 @@ +"""Tests for the training data validation module.""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +import types +from pathlib import Path + +import pytest + +ROOT = Path(__file__).resolve().parents[1] + +# Stub voxcpm package so imports work without full dependencies +pkg = types.ModuleType("voxcpm") +pkg.__path__ = [str(ROOT / "src" / "voxcpm")] +sys.modules.setdefault("voxcpm", pkg) + +training_pkg = types.ModuleType("voxcpm.training") +training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")] +sys.modules.setdefault("voxcpm.training", training_pkg) + +from voxcpm.training.validate import ValidationResult, validate_manifest + + +@pytest.fixture +def tmp_dir(): + with tempfile.TemporaryDirectory() as d: + yield Path(d) + + +def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000): + """Create a minimal valid WAV file.""" + try: + import soundfile as sf + import numpy as np + + samples = int(duration_s * sr) + data = np.zeros(samples, dtype=np.float32) + sf.write(str(path), data, sr) + except ImportError: + # If soundfile is not available, create a minimal WAV header + import struct + + samples = int(duration_s * sr) + data_size = samples * 2 # 16-bit PCM + with open(path, "wb") as f: + f.write(b"RIFF") + f.write(struct.pack("= 2 + assert any("Invalid JSON" in e for e in result.errors) + + def test_missing_columns(self, tmp_dir): + manifest = tmp_dir / "missing.jsonl" + _write_manifest( + manifest, + [ + {"text": "hello"}, # missing audio + {"audio": "test.wav"}, # missing text + ], + ) + result = validate_manifest(str(manifest)) + assert len(result.errors) >= 2 + assert any("'audio'" in e for e in result.errors) + assert any("'text'" in e for e in result.errors) + + def test_missing_audio_file(self, tmp_dir): + manifest = tmp_dir / "missing_audio.jsonl" + _write_manifest( + manifest, + [{"text": "hello", "audio": "/nonexistent/audio.wav"}], + ) + result = validate_manifest(str(manifest)) + assert not result.is_valid + assert any("not found" in e for e in result.errors) + + def test_empty_text_warning(self, tmp_dir): + audio = tmp_dir / "audio.wav" + _create_wav(audio) + manifest = tmp_dir / "empty_text.jsonl" + _write_manifest( + manifest, + [{"text": "", "audio": str(audio)}], + ) + result = validate_manifest(str(manifest)) + assert len(result.warnings) > 0 + assert any("Empty" in w for w in result.warnings) + + def test_relative_audio_path(self, tmp_dir): + audio = tmp_dir / "audio.wav" + _create_wav(audio) + manifest = tmp_dir / "rel.jsonl" + _write_manifest( + manifest, + [{"text": "hello", "audio": "audio.wav"}], + ) + result = validate_manifest(str(manifest)) + assert result.valid_samples == 1 + assert result.is_valid + + def test_max_samples_limit(self, tmp_dir): + audio = tmp_dir / "audio.wav" + _create_wav(audio) + manifest = tmp_dir / "many.jsonl" + _write_manifest( + manifest, + [{"text": f"sample {i}", "audio": str(audio)} for i in range(100)], + ) + result = validate_manifest(str(manifest), max_samples=10) + assert result.total_samples == 10 + + def test_ref_audio_counted(self, tmp_dir): + audio = tmp_dir / "audio.wav" + ref = tmp_dir / "ref.wav" + _create_wav(audio) + _create_wav(ref) + manifest = tmp_dir / "ref.jsonl" + _write_manifest( + manifest, + [{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}], + ) + result = validate_manifest(str(manifest)) + assert result.has_ref_audio == 1 + + def test_validation_result_properties(self): + r = ValidationResult(total_samples=5, valid_samples=5) + assert r.is_valid + + r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"]) + assert not r2.is_valid + + r3 = ValidationResult(total_samples=0, valid_samples=0) + assert not r3.is_valid + + def test_invalid_audio_not_counted_as_valid(self, tmp_dir): + """A row with a bad audio path must not increment valid_samples.""" + manifest = tmp_dir / "bad_audio.jsonl" + _write_manifest( + manifest, + [{"text": "hello", "audio": "/nonexistent/audio.wav"}], + ) + result = validate_manifest(str(manifest)) + assert result.total_samples == 1 + assert result.valid_samples == 0 + assert not result.is_valid + assert any("not found" in e for e in result.errors) + + def test_sample_rate_mismatch(self, tmp_dir): + """A file with a different sample rate should be reported as an error.""" + try: + import soundfile as sf + import numpy as np + except ImportError: + pytest.skip("soundfile not available") + + audio = tmp_dir / "audio_8k.wav" + import numpy as np + samples = np.zeros(8000, dtype=np.float32) + sf.write(str(audio), samples, 8000) + + manifest = tmp_dir / "sr_mismatch.jsonl" + _write_manifest(manifest, [{"text": "hello", "audio": str(audio)}]) + + result = validate_manifest(str(manifest), sample_rate=16000) + assert result.valid_samples == 0 + assert not result.is_valid + assert any("Sample rate mismatch" in e or "sample rate" in e.lower() for e in result.errors) + + def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir): + """Missing ref_audio entries should each generate a warning independently.""" + audio = tmp_dir / "audio.wav" + ref_good = tmp_dir / "ref_good.wav" + _create_wav(audio) + _create_wav(ref_good) + + manifest = tmp_dir / "mixed_ref.jsonl" + _write_manifest( + manifest, + [ + {"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)}, + {"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"}, + ], + ) + result = validate_manifest(str(manifest)) + assert result.has_ref_audio == 1 + assert any("ref_audio file not found" in w for w in result.warnings) + + def test_cli_validate_exit_code(self, tmp_dir): + """validate subcommand must exit 1 on validation error (missing audio).""" + import subprocess + manifest = tmp_dir / "bad.jsonl" + _write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}]) + + proc = subprocess.run( + [sys.executable, "-m", "voxcpm.cli", "validate", "--manifest", str(manifest)], + capture_output=True, + text=True, + ) + assert proc.returncode == 1, f"Expected exit 1, got {proc.returncode}" + assert "FAILED" in proc.stderr or "Audio file not found" in proc.stderr