@@ -136,17 +136,20 @@ def _load_model_sync(self, model_size: str):
136136 # The multilingual model's .pt files were saved on CUDA and
137137 # from_local() doesn't pass map_location, so loading on CPU fails.
138138 if device == "cpu" :
139+ import threading
139140 _orig_torch_load = torch .load
141+ _load_lock = threading .Lock ()
140142
141143 def _patched_load (* args , ** kwargs ):
142144 kwargs .setdefault ("map_location" , "cpu" )
143145 return _orig_torch_load (* args , ** kwargs )
144146
145- torch .load = _patched_load
146- try :
147- self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
148- finally :
149- torch .load = _orig_torch_load
147+ with _load_lock :
148+ torch .load = _patched_load
149+ try :
150+ self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
151+ finally :
152+ torch .load = _orig_torch_load
150153 else :
151154 self .model = ChatterboxMultilingualTTS .from_pretrained (device = device )
152155
@@ -171,8 +174,8 @@ def _patched_load(*args, **kwargs):
171174
172175 except ImportError as e :
173176 print (
174- f "Error: chatterbox-tts package not found. "
175- f "Install with: pip install chatterbox-tts"
177+ "Error: chatterbox-tts package not found. "
178+ "Install with: pip install chatterbox-tts"
176179 )
177180 progress_manager = get_progress_manager ()
178181 task_manager = get_task_manager ()
@@ -218,9 +221,13 @@ def _patched_add_hebrew_diacritics(text: str) -> str:
def unload_model(self) -> None:
    """Drop the loaded model so its memory can be reclaimed.

    Does nothing when no model is currently loaded. Otherwise the model
    reference and the recorded device are cleared, and if the model had
    been placed on CUDA, PyTorch is asked to empty its CUDA cache.
    """
    if self.model is None:
        return

    # Capture the device before it is reset; it decides whether the
    # CUDA cache needs to be emptied after the model is released.
    previous_device = self._device

    del self.model
    self.model = None
    self._device = None

    if previous_device == "cuda":
        # Imported lazily so the non-CUDA path never touches torch here.
        import torch

        torch.cuda.empty_cache()

    print("Chatterbox Multilingual TTS model unloaded")
225232
226233 async def create_voice_prompt (
@@ -250,7 +257,7 @@ async def combine_voice_prompts(
250257 combined_audio = []
251258
252259 for audio_path in audio_paths :
253- audio , sr = load_audio (audio_path )
260+ audio , _sr = load_audio (audio_path )
254261 audio = normalize_audio (audio )
255262 combined_audio .append (audio )
256263
@@ -334,8 +341,7 @@ def _generate_sync():
334341 else :
335342 audio = np .asarray (wav , dtype = np .float32 )
336343
337- # Chatterbox default sample rate is 24000
338- sample_rate = 24000
344+ sample_rate = getattr (self .model , 'sr' , None ) or getattr (self .model , 'sample_rate' , 24000 )
339345
340346 return audio , sample_rate
341347
0 commit comments