From df639b5ed638bdd41c04ae76bb7cd6a6eb57db25 Mon Sep 17 00:00:00 2001 From: Szymon Stasik Date: Mon, 11 May 2026 19:46:08 +0000 Subject: [PATCH] feat(stt): expose word_timestamps on /v1/audio/transcriptions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `word_timestamps: bool = Form(False)` and `timestamp_granularities: Optional[str] = Form(None)` form fields to the `/v1/audio/transcriptions` route handler, and matching `word_timestamps: bool = False` and `timestamp_granularities: Optional[str] = None` fields to the `TranscriptionRequest` Pydantic model that backs it. `mlx_whisper.transcribe()` already accepts `word_timestamps` and returns `segments[].words[] = {start, end, probability, word}` when it is set, but the server's `STTExecutionAdapter` kwarg-filter step was silently dropping it because it used a strict signature-parameter allowlist. The fix adds `_STT_EXTRA_KWARGS = {"word_timestamps", "timestamp_granularities"}` — a small explicit allowlist that bypasses the signature check for these two fields — so they are always forwarded to the underlying model's `generate()` call regardless of how its signature is declared. This relies on `generate()` accepting these names, either as declared parameters or via `**kwargs`; a `generate()` that accepts neither would raise `TypeError` when they are forwarded. The `word_timestamps` field works with any mlx-whisper model (fp16, quantized q4, etc.) because the kwarg is handled by `mlx_whisper.transcribe` itself, not by a model-specific code path. No change to the response-shaping layer is needed: `verbose_json` already returns the full model payload unchanged, so `segments[].words[]` is included automatically once the kwarg reaches the model. New form fields are optional and default to `False`/`None`, so all existing callers are unaffected (no breaking change). 
Tests added to `mlx_audio/tests/test_server.py`: - Unit test: `TranscriptionRequest` defaults + field acceptance - Integration test: `word_timestamps=true` reaches `generate()` kwargs - Integration test: `verbose_json` response includes `words[]` from model --- mlx_audio/server.py | 11 +++- mlx_audio/tests/test_server.py | 97 ++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/mlx_audio/server.py b/mlx_audio/server.py index 9bfe5692..8333108b 100644 --- a/mlx_audio/server.py +++ b/mlx_audio/server.py @@ -188,6 +188,8 @@ class TranscriptionRequest(BaseModel): context: str | None = None prefill_step_size: int = 2048 text: str | None = None + word_timestamps: bool = False + timestamp_granularities: Optional[str] = None class SeparationResponse(BaseModel): @@ -229,6 +231,9 @@ def _load_model_for_inference(model_name: str): return model_provider.load_model(model_name) +_STT_EXTRA_KWARGS = {"word_timestamps", "timestamp_granularities"} + + class STTExecutionAdapter(BaseModelExecutionAdapter): def run_serial(self, request: InferenceRequest) -> None: payload: TranscriptionTaskPayload = request.payload @@ -246,7 +251,7 @@ def run_serial(self, request: InferenceRequest) -> None: gen_kwargs = { key: value for key, value in gen_kwargs.items() - if key in signature.parameters + if key in signature.parameters or key in _STT_EXTRA_KWARGS } result = stt_model.generate(tmp_path, **gen_kwargs) @@ -931,6 +936,8 @@ async def stt_transcriptions( prefill_step_size: int = Form(2048), text: Optional[str] = Form(None), response_format: str = Form("ndjson"), + word_timestamps: bool = Form(False), + timestamp_granularities: Optional[str] = Form(None), ): """Transcribe audio using an STT model. 
@@ -960,6 +967,8 @@ async def stt_transcriptions( context=context, prefill_step_size=prefill_step_size, text=text, + word_timestamps=word_timestamps, + timestamp_granularities=timestamp_granularities, ) data = await file.read() tmp = io.BytesIO(data) diff --git a/mlx_audio/tests/test_server.py b/mlx_audio/tests/test_server.py index cb899978..fc280fc9 100644 --- a/mlx_audio/tests/test_server.py +++ b/mlx_audio/tests/test_server.py @@ -474,3 +474,100 @@ def test_realtime_ws_streaming_disabled_fallback(client, mock_model_provider): # Should have at least one legacy text message text_msgs = [m for m in messages if "text" in m and "type" not in m] assert len(text_msgs) >= 1, f"Expected legacy text message, got: {messages}" + + +# --------------------------------------------------------------------------- +# word_timestamps form field tests +# --------------------------------------------------------------------------- + + +def test_transcription_request_word_timestamps_defaults(): + """TranscriptionRequest defaults word_timestamps=False, timestamp_granularities=None.""" + from mlx_audio.server import TranscriptionRequest + + req = TranscriptionRequest(model="test-model") + assert req.word_timestamps is False + assert req.timestamp_granularities is None + + +def test_transcription_request_word_timestamps_accepted(): + """TranscriptionRequest accepts word_timestamps=True.""" + from mlx_audio.server import TranscriptionRequest + + req = TranscriptionRequest(model="test-model", word_timestamps=True, timestamp_granularities="word") + assert req.word_timestamps is True + assert req.timestamp_granularities == "word" + + +def test_stt_word_timestamps_passed_to_generate(client, mock_model_provider): + """word_timestamps=true form field reaches stt_model.generate() as a kwarg. + + The STTExecutionAdapter allowlist (_STT_EXTRA_KWARGS) must pass word_timestamps + through even when it isn't declared in the model's generate() signature. 
+ """ + captured_kwargs: dict = {} + + def mock_generate(path, **kwargs): + captured_kwargs.update(kwargs) + return {"text": "hello", "segments": [], "language": "en"} + + mock_stt_model = MagicMock() + mock_stt_model.generate = mock_generate + mock_model_provider.load_model = MagicMock(return_value=mock_stt_model) + + response = client.post( + "/v1/audio/transcriptions", + files={"file": ("test.mp3", _make_transcription_audio_buffer(), "audio/mp3")}, + data={ + "model": "test_stt_model", + "response_format": "verbose_json", + "word_timestamps": "true", + }, + ) + + assert response.status_code == 200 + assert captured_kwargs.get("word_timestamps") is True + + +def test_stt_word_timestamps_verbose_json_words_passthrough(client, mock_model_provider): + """verbose_json response includes words[] from the model when word_timestamps=True.""" + full_payload = { + "text": "Hello world.", + "language": "en", + "segments": [ + { + "id": 0, + "text": "Hello world.", + "start": 0.0, + "end": 1.0, + "words": [ + {"word": "Hello", "start": 0.0, "end": 0.5, "probability": 0.99}, + {"word": "world.", "start": 0.5, "end": 1.0, "probability": 0.98}, + ], + } + ], + } + + mock_stt_model = MagicMock() + mock_stt_model.generate = MagicMock(return_value=full_payload) + mock_model_provider.load_model = MagicMock(return_value=mock_stt_model) + + response = client.post( + "/v1/audio/transcriptions", + files={"file": ("test.mp3", _make_transcription_audio_buffer(), "audio/mp3")}, + data={ + "model": "test_stt_model", + "response_format": "verbose_json", + "word_timestamps": "true", + }, + ) + + assert response.status_code == 200 + body = response.json() + assert body["text"] == "Hello world." + segments = body.get("segments", []) + assert len(segments) == 1 + words = segments[0].get("words", []) + assert len(words) == 2 + assert words[0]["word"] == "Hello" + assert words[1]["word"] == "world."