From b9c3b8de99909fe7dbc84631cd2837c9da4dcc14 Mon Sep 17 00:00:00 2001 From: orbitalquark <70453897+orbitalquark@users.noreply.github.com> Date: Sun, 5 Apr 2026 12:09:54 -0400 Subject: [PATCH] Register server disconnects while streaming TTS audio. No need to continue generating for clients that have hung up. This reduces latency processing the next request. --- mlx_audio/server.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mlx_audio/server.py b/mlx_audio/server.py index 39e9dae10..7f8a4bf3f 100644 --- a/mlx_audio/server.py +++ b/mlx_audio/server.py @@ -29,6 +29,7 @@ File, Form, HTTPException, + Request, Response, UploadFile, WebSocket, @@ -260,7 +261,7 @@ async def remove_model(model_name: str): raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") -async def generate_audio(model, payload: SpeechRequest): +async def generate_audio(model, payload: SpeechRequest, request: Request): # Load reference audio if provided ref_audio = payload.ref_audio audio_chunks = [] @@ -300,6 +301,10 @@ async def generate_audio(model, payload: SpeechRequest): verbose=payload.verbose, ): + if await request.is_disconnected(): + mx.clear_cache() + return + if payload.stream: buffer = io.BytesIO() audio_write( @@ -311,6 +316,8 @@ async def generate_audio(model, payload: SpeechRequest): if sample_rate is None: sample_rate = result.sample_rate + await asyncio.sleep(0) # register any disconnects + if payload.stream: return @@ -324,11 +331,11 @@ async def generate_audio(model, payload: SpeechRequest): @app.post("/v1/audio/speech") -async def tts_speech(payload: SpeechRequest): +async def tts_speech(payload: SpeechRequest, request: Request): """Generate speech audio following the OpenAI text-to-speech API.""" model = model_provider.load_model(payload.model) return StreamingResponse( - generate_audio(model, payload), + generate_audio(model, payload, request), media_type=f"audio/{payload.response_format}", headers={ "Content-Disposition": f"attachment; filename=speech.{payload.response_format}"