Skip to content

Commit f006ee2

Browse files
nadavox and claude committed
fix: address self-review and CodeRabbit findings
Critical fixes:
- Add tts_backend parameter to create_voice_prompt_for_profile (was crashing Hebrew generation)
- Fix double-wrapped language tokens in Whisper suppress_tokens (was producing malformed IDs)
- Add threading lock to torch.load monkey-patch for thread safety

High priority:
- Deduplicate STT_MODEL_MAP into backends/__init__.py (was copy-pasted in mlx + pytorch)
- Scope trim_tts_output to Chatterbox only (was aggressively trimming Qwen output)
- Expand TranscriptionRequest language pattern to all 11 supported languages
- Add Chatterbox sub-dependencies to requirements.txt (conformer, diffusers, etc.)
- Read sample_rate from Chatterbox model object instead of hardcoding 24000

Cleanup:
- Remove duplicate import asyncio
- Remove console.log debug statements from client.ts triggerModelDownload
- Add CUDA cache cleanup to Chatterbox unload_model
- Add chatterbox unload to shutdown handler
- Fix unused variable warnings (sr -> _sr)
- Fix f-strings without placeholders

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c05bf3e commit f006ee2

10 files changed

Lines changed: 56 additions & 49 deletions

File tree

app/src/lib/api/client.ts

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -310,13 +310,10 @@ class ApiClient {
310310
}
311311

312312
async triggerModelDownload(modelName: string): Promise<{ message: string }> {
313-
console.log('[API] triggerModelDownload called for:', modelName, 'at', new Date().toISOString());
314-
const result = await this.request<{ message: string }>('/models/download', {
313+
return this.request<{ message: string }>('/models/download', {
315314
method: 'POST',
316315
body: JSON.stringify({ model_name: modelName } as ModelDownloadRequest),
317316
});
318-
console.log('[API] triggerModelDownload response:', result);
319-
return result;
320317
}
321318

322319
async deleteModel(modelName: string): Promise<{ message: string }> {

backend/backends/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,17 @@
1010

1111
from ..platform_detect import get_backend_type
1212

13+
# Shared model name mapping for STT backends (MLX + PyTorch).
14+
# Maps short model size keys to HuggingFace repo IDs.
15+
STT_MODEL_MAP = {
16+
"base": "openai/whisper-base",
17+
"small": "openai/whisper-small",
18+
"medium": "openai/whisper-medium",
19+
"large": "openai/whisper-large",
20+
"ivrit-v3": "ivrit-ai/whisper-large-v3",
21+
"ivrit-v3-turbo": "ivrit-ai/whisper-large-v3-turbo",
22+
}
23+
1324

1425
@runtime_checkable
1526
class TTSBackend(Protocol):

backend/backends/chatterbox_backend.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -136,17 +136,20 @@ def _load_model_sync(self, model_size: str):
136136
# The multilingual model's .pt files were saved on CUDA and
137137
# from_local() doesn't pass map_location, so loading on CPU fails.
138138
if device == "cpu":
139+
import threading
139140
_orig_torch_load = torch.load
141+
_load_lock = threading.Lock()
140142

141143
def _patched_load(*args, **kwargs):
142144
kwargs.setdefault("map_location", "cpu")
143145
return _orig_torch_load(*args, **kwargs)
144146

145-
torch.load = _patched_load
146-
try:
147-
self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
148-
finally:
149-
torch.load = _orig_torch_load
147+
with _load_lock:
148+
torch.load = _patched_load
149+
try:
150+
self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
151+
finally:
152+
torch.load = _orig_torch_load
150153
else:
151154
self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
152155

@@ -171,8 +174,8 @@ def _patched_load(*args, **kwargs):
171174

172175
except ImportError as e:
173176
print(
174-
f"Error: chatterbox-tts package not found. "
175-
f"Install with: pip install chatterbox-tts"
177+
"Error: chatterbox-tts package not found. "
178+
"Install with: pip install chatterbox-tts"
176179
)
177180
progress_manager = get_progress_manager()
178181
task_manager = get_task_manager()
@@ -218,9 +221,13 @@ def _patched_add_hebrew_diacritics(text: str) -> str:
218221
def unload_model(self) -> None:
219222
"""Unload model to free memory."""
220223
if self.model is not None:
224+
device = self._device
221225
del self.model
222226
self.model = None
223227
self._device = None
228+
if device == "cuda":
229+
import torch
230+
torch.cuda.empty_cache()
224231
print("Chatterbox Multilingual TTS model unloaded")
225232

226233
async def create_voice_prompt(
@@ -250,7 +257,7 @@ async def combine_voice_prompts(
250257
combined_audio = []
251258

252259
for audio_path in audio_paths:
253-
audio, sr = load_audio(audio_path)
260+
audio, _sr = load_audio(audio_path)
254261
audio = normalize_audio(audio)
255262
combined_audio.append(audio)
256263

@@ -334,8 +341,7 @@ def _generate_sync():
334341
else:
335342
audio = np.asarray(wav, dtype=np.float32)
336343

337-
# Chatterbox default sample rate is 24000
338-
sample_rate = 24000
344+
sample_rate = getattr(self.model, 'sr', None) or getattr(self.model, 'sample_rate', 24000)
339345

340346
return audio, sample_rate
341347

backend/backends/mlx_backend.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,7 @@ async def generate_with_adapter(
400400
return await self.generate(text, voice_prompt, language, seed, instruct)
401401

402402

403-
STT_MODEL_MAP = {
404-
"base": "openai/whisper-base",
405-
"small": "openai/whisper-small",
406-
"medium": "openai/whisper-medium",
407-
"large": "openai/whisper-large",
408-
"ivrit-v3": "ivrit-ai/whisper-large-v3",
409-
"ivrit-v3-turbo": "ivrit-ai/whisper-large-v3-turbo",
410-
}
403+
from . import STT_MODEL_MAP
411404

412405

413406
class MLXSTTBackend:

backend/backends/pytorch_backend.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -369,14 +369,7 @@ def _generate_sync():
369369
return audio, sample_rate
370370

371371

372-
STT_MODEL_MAP = {
373-
"base": "openai/whisper-base",
374-
"small": "openai/whisper-small",
375-
"medium": "openai/whisper-medium",
376-
"large": "openai/whisper-large",
377-
"ivrit-v3": "ivrit-ai/whisper-large-v3",
378-
"ivrit-v3-turbo": "ivrit-ai/whisper-large-v3-turbo",
379-
}
372+
from . import STT_MODEL_MAP
380373

381374

382375
class PyTorchSTTBackend:
@@ -608,15 +601,13 @@ def _transcribe_sync():
608601
tokenizer = self.processor.tokenizer
609602
lang_token = f"<|{language}|>"
610603
if lang_token in tokenizer.get_vocab():
611-
lang_id = tokenizer.convert_tokens_to_ids(lang_token)
612604
# Suppress all other language tokens to prevent drift
613605
all_lang_tokens = [
614-
tokenizer.convert_tokens_to_ids(f"<|{lang}|>")
615-
for lang in tokenizer.additional_special_tokens
616-
if lang.startswith("<|") and lang.endswith("|>")
617-
and lang != lang_token and lang != "<|transcribe|>"
618-
and lang != "<|notimestamps|>"
619-
and tokenizer.convert_tokens_to_ids(lang) != tokenizer.unk_token_id
606+
tokenizer.convert_tokens_to_ids(tok)
607+
for tok in tokenizer.additional_special_tokens
608+
if tok.startswith("<|") and tok.endswith("|>")
609+
and tok not in (lang_token, "<|transcribe|>", "<|notimestamps|>")
610+
and tokenizer.convert_tokens_to_ids(tok) != tokenizer.unk_token_id
620611
]
621612
if all_lang_tokens:
622613
generate_kwargs["suppress_tokens"] = all_lang_tokens

backend/main.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
import io
2020
from pathlib import Path
2121
import uuid
22-
import asyncio
2322
import signal
2423
import os
2524
from urllib.parse import quote
@@ -704,9 +703,10 @@ async def download_model_background():
704703
data.instruct,
705704
)
706705

707-
# Trim trailing silence/noise from TTS output (known Chatterbox issue)
708-
from .utils.audio import trim_tts_output
709-
audio = trim_tts_output(audio, sample_rate)
706+
# Trim trailing silence/noise from Chatterbox output (known hallucination issue)
707+
if data.language == "he":
708+
from .utils.audio import trim_tts_output
709+
audio = trim_tts_output(audio, sample_rate)
710710

711711
# Calculate duration
712712
duration = len(audio) / sample_rate
@@ -2010,6 +2010,7 @@ async def shutdown_event():
20102010
print("voicebox API shutting down...")
20112011
# Unload models to free memory
20122012
tts.unload_tts_model()
2013+
tts.unload_chatterbox_model()
20132014
transcribe.unload_whisper_model()
20142015

20152016

backend/models.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,7 @@ class HistoryListResponse(BaseModel):
108108

109109
class TranscriptionRequest(BaseModel):
110110
"""Request model for audio transcription."""
111-
language: Optional[str] = Field(None, pattern="^(en|zh|he)$")
111+
language: Optional[str] = Field(None, pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it|he)$")
112112

113113

114114
class TranscriptionResponse(BaseModel):

backend/profiles.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -327,6 +327,7 @@ async def create_voice_prompt_for_profile(
327327
profile_id: str,
328328
db: Session,
329329
use_cache: bool = True,
330+
tts_backend=None,
330331
) -> dict:
331332
"""
332333
Create a combined voice prompt from all samples in a profile.
@@ -335,6 +336,7 @@ async def create_voice_prompt_for_profile(
335336
profile_id: Profile ID
336337
db: Database session
337338
use_cache: Whether to use cached prompts
339+
tts_backend: Optional TTS backend override (e.g. Chatterbox for Hebrew)
338340
339341
Returns:
340342
Voice prompt dictionary
@@ -345,12 +347,12 @@ async def create_voice_prompt_for_profile(
345347
if not samples:
346348
raise ValueError(f"No samples found for profile {profile_id}")
347349

348-
tts_model = get_tts_model()
350+
backend = tts_backend or get_tts_model()
349351

350352
if len(samples) == 1:
351353
# Single sample - use directly
352354
sample = samples[0]
353-
voice_prompt, _ = await tts_model.create_voice_prompt(
355+
voice_prompt, _ = await backend.create_voice_prompt(
354356
sample.audio_path,
355357
sample.reference_text,
356358
use_cache=use_cache,
@@ -362,7 +364,7 @@ async def create_voice_prompt_for_profile(
362364
reference_texts = [s.reference_text for s in samples]
363365

364366
# Combine audio
365-
combined_audio, combined_text = await tts_model.combine_voice_prompts(
367+
combined_audio, combined_text = await backend.combine_voice_prompts(
366368
audio_paths,
367369
reference_texts,
368370
)

backend/requirements.txt

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,16 @@ transformers>=4.36.0
1313
accelerate>=0.26.0
1414
huggingface_hub>=0.20.0
1515
qwen-tts>=0.0.5
16-
# Hebrew TTS — install with: pip install chatterbox-tts --no-deps
17-
# then install missing sub-deps: pip install conformer diffusers omegaconf pykakasi resemble-perth s3tokenizer
18-
# (numpy constraint on PyPI is too strict; works fine with numpy 2.x in practice)
16+
# Hebrew TTS (Chatterbox)
17+
# Note: chatterbox-tts has a strict numpy<2 pin on PyPI but works fine with numpy 2.x.
18+
# If pip fails, install with: pip install chatterbox-tts --no-deps
1919
chatterbox-tts>=0.1.0
20+
conformer
21+
diffusers
22+
omegaconf
23+
pykakasi
24+
resemble-perth
25+
s3tokenizer
2026

2127
# Audio processing
2228
librosa>=0.10.0

backend/utils/audio.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def save_audio(
8282

8383
def prepare_for_transcription(
8484
audio: np.ndarray,
85-
sr: int,
85+
sr: int = 16000, # noqa: ARG001 — kept for API consistency
8686
) -> np.ndarray:
8787
"""
8888
Prepare audio for Whisper transcription.

0 commit comments

Comments
 (0)