Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions backend/backends/mlx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
MLX backend implementation for TTS and STT using mlx-audio.
"""

from typing import Optional, List, Tuple
import asyncio
import logging
import numpy as np
from pathlib import Path

import numpy as np

logger = logging.getLogger(__name__)

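# PATCH: Import and apply the offline patch BEFORE anything else uses huggingface_hub.
# This prevents mlx_audio from making network requests when the models are already cached,
# so the patch import and the two calls below must stay ahead of the remaining imports.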
from ..utils.hf_offline_patch import patch_huggingface_hub_offline, ensure_original_qwen_config_cached
from ..utils.hf_offline_patch import ensure_original_qwen_config_cached, patch_huggingface_hub_offline

patch_huggingface_hub_offline()
ensure_original_qwen_config_cached()

from . import TTSBackend, STTBackend, LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import is_model_cached, combine_voice_prompts as _combine_voice_prompts, model_load_progress
from ..utils.cache import get_cache_key, get_cached_voice_prompt, cache_voice_prompt
from ..utils.cache import cache_voice_prompt, get_cache_key, get_cached_voice_prompt
from . import LANGUAGE_CODE_TO_NAME, WHISPER_HF_REPOS
from .base import combine_voice_prompts as _combine_voice_prompts, is_model_cached, model_load_progress


class MLXTTSBackend:
Expand Down Expand Up @@ -63,7 +63,7 @@ def _is_model_cached(self, model_size: str) -> bool:
weight_extensions=(".safetensors", ".bin", ".npz"),
)

async def load_model_async(self, model_size: Optional[str] = None):
async def load_model_async(self, model_size: str | None = None):
"""
Lazy load the MLX TTS model.

Expand Down Expand Up @@ -100,6 +100,19 @@ def _load_model_sync(self, model_size: str):

self.model = load(model_path)

import inspect

self._supports_ref_audio = "ref_audio" in inspect.signature(self.model.generate).parameters

# Warm up Metal JIT kernels — first inference compiles shaders, shift cost to load time
try:
logger.info("Warming up Metal kernels...")
for _ in self.model.generate("Hello.", lang_code="english"):
break # one token is enough to trigger compilation
logger.info("Metal warmup complete")
except Exception as e:
logger.warning("Warmup failed (non-fatal): %s", e)

self._current_model_size = model_size
self.model_size = model_size
logger.info("MLX TTS model %s loaded successfully", model_size)
Expand All @@ -117,7 +130,7 @@ async def create_voice_prompt(
audio_path: str,
reference_text: str,
use_cache: bool = True,
) -> Tuple[dict, bool]:
) -> tuple[dict, bool]:
"""
Create voice prompt from reference audio.

Expand Down Expand Up @@ -145,9 +158,8 @@ async def create_voice_prompt(
cached_audio_path = cached_prompt.get("ref_audio") or cached_prompt.get("ref_audio_path")
if cached_audio_path and Path(cached_audio_path).exists():
return cached_prompt, True
else:
# Cached file no longer exists, invalidate cache
logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)
# Cached audio file no longer exists: ignore the stale cache entry and fall through to rebuild the prompt
logger.warning("Cached audio file not found: %s, regenerating prompt", cached_audio_path)

# MLX voice prompt format - store audio path and text
# The model will process this during generation
Expand All @@ -171,9 +183,9 @@ async def generate(
text: str,
voice_prompt: dict,
language: str = "en",
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
seed: int | None = None,
instruct: str | None = None,
) -> tuple[np.ndarray, int]:
"""
Generate audio from text using voice prompt.

Expand Down Expand Up @@ -223,11 +235,8 @@ def _generate_sync():
# legitimate metadata calls during generation.
try:
if ref_audio:
# Check if generate accepts ref_audio parameter
import inspect

sig = inspect.signature(self.model.generate)
if "ref_audio" in sig.parameters:
# Use cached capability flag set at model load time
if self._supports_ref_audio:
# Generate with voice cloning
for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
Expand Down Expand Up @@ -279,7 +288,7 @@ def _is_model_cached(self, model_size: str) -> bool:
hf_repo = WHISPER_HF_REPOS.get(model_size, f"openai/whisper-{model_size}")
return is_model_cached(hf_repo, weight_extensions=(".safetensors", ".bin", ".npz"))

async def load_model_async(self, model_size: Optional[str] = None):
async def load_model_async(self, model_size: str | None = None):
"""
Lazy load the MLX Whisper model.

Expand Down Expand Up @@ -324,8 +333,8 @@ def unload_model(self):
async def transcribe(
self,
audio_path: str,
language: Optional[str] = None,
model_size: Optional[str] = None,
language: str | None = None,
model_size: str | None = None,
) -> str:
"""
Transcribe audio to text.
Expand Down Expand Up @@ -356,12 +365,11 @@ def _transcribe_sync():
# Extract text from result
if isinstance(result, str):
return result.strip()
elif isinstance(result, dict):
if isinstance(result, dict):
return result.get("text", "").strip()
elif hasattr(result, "text"):
if hasattr(result, "text"):
return result.text.strip()
else:
return str(result).strip()
return str(result).strip()

# Run blocking transcription in thread pool
return await asyncio.to_thread(_transcribe_sync)
22 changes: 10 additions & 12 deletions backend/routes/health.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
"""Health and infrastructure endpoints."""

import asyncio
import contextlib
import os
import signal
from pathlib import Path

import torch
from fastapi import APIRouter, Depends
from fastapi import APIRouter
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session

from .. import config, models
from ..services import tts
from ..database import get_db
from ..utils.platform_detect import get_backend_type

router = APIRouter()
Expand Down Expand Up @@ -40,7 +39,7 @@ async def shutdown_async():
await asyncio.sleep(0.1)
os.kill(os.getpid(), signal.SIGTERM)

asyncio.create_task(shutdown_async())
asyncio.create_task(shutdown_async()) # noqa: RUF006 — fire-and-forget shutdown
return {"message": "Shutting down..."}


Expand All @@ -56,9 +55,10 @@ async def watchdog_disable():
@router.get("/health", response_model=models.HealthResponse)
async def health():
"""Health check endpoint."""
from huggingface_hub import constants as hf_constants
from pathlib import Path

from huggingface_hub import constants as hf_constants

tts_model = tts.get_tts_model()
backend_type = get_backend_type()

Expand Down Expand Up @@ -117,10 +117,8 @@ async def health():
if has_cuda:
vram_used = torch.cuda.memory_allocated() / 1024 / 1024
elif has_xpu:
try:
with contextlib.suppress(Exception): # memory_allocated() may not be available on all IPEX versions
vram_used = torch.xpu.memory_allocated() / 1024 / 1024
except Exception:
pass # memory_allocated() may not be available on all IPEX versions

model_loaded = False
model_size = None
Expand Down Expand Up @@ -175,7 +173,9 @@ async def health():
backend_type=backend_type,
backend_variant=os.environ.get(
"VOICEBOX_BACKEND_VARIANT",
"cuda" if torch.cuda.is_available() else ("xpu" if has_xpu else "cpu"),
"cuda"
if torch.cuda.is_available()
else ("xpu" if has_xpu else ("metal" if backend_type == "mlx" else "cpu")),
),
gpu_compatibility_warning=gpu_compat_warning,
)
Expand Down Expand Up @@ -211,10 +211,8 @@ async def filesystem_health():
except OSError as e:
error = str(e)
finally:
try:
with contextlib.suppress(Exception):
probe.unlink(missing_ok=True)
except Exception:
pass
else:
error = "Directory does not exist"

Expand Down