Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion mlx_audio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ class TranscriptionRequest(BaseModel):
context: str | None = None
prefill_step_size: int = 2048
text: str | None = None
word_timestamps: bool = False
timestamp_granularities: Optional[str] = None


class SeparationResponse(BaseModel):
Expand Down Expand Up @@ -229,6 +231,9 @@ def _load_model_for_inference(model_name: str):
return model_provider.load_model(model_name)


_STT_EXTRA_KWARGS = {"word_timestamps", "timestamp_granularities"}


class STTExecutionAdapter(BaseModelExecutionAdapter):
def run_serial(self, request: InferenceRequest) -> None:
payload: TranscriptionTaskPayload = request.payload
Expand All @@ -246,7 +251,7 @@ def run_serial(self, request: InferenceRequest) -> None:
gen_kwargs = {
key: value
for key, value in gen_kwargs.items()
if key in signature.parameters
if key in signature.parameters or key in _STT_EXTRA_KWARGS
}

result = stt_model.generate(tmp_path, **gen_kwargs)
Expand Down Expand Up @@ -931,6 +936,8 @@ async def stt_transcriptions(
prefill_step_size: int = Form(2048),
text: Optional[str] = Form(None),
response_format: str = Form("ndjson"),
word_timestamps: bool = Form(False),
timestamp_granularities: Optional[str] = Form(None),
):
"""Transcribe audio using an STT model.

Expand Down Expand Up @@ -960,6 +967,8 @@ async def stt_transcriptions(
context=context,
prefill_step_size=prefill_step_size,
text=text,
word_timestamps=word_timestamps,
timestamp_granularities=timestamp_granularities,
)
data = await file.read()
tmp = io.BytesIO(data)
Expand Down
97 changes: 97 additions & 0 deletions mlx_audio/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,100 @@ def test_realtime_ws_streaming_disabled_fallback(client, mock_model_provider):
# Should have at least one legacy text message
text_msgs = [m for m in messages if "text" in m and "type" not in m]
assert len(text_msgs) >= 1, f"Expected legacy text message, got: {messages}"


# ---------------------------------------------------------------------------
# word_timestamps / timestamp_granularities form field tests
# ---------------------------------------------------------------------------


def test_transcription_request_word_timestamps_defaults():
    """A request built with only a model name leaves both timestamp options off."""
    from mlx_audio import server

    bare_request = server.TranscriptionRequest(model="test-model")

    # Identity checks: the defaults must be exactly False / None, not merely falsy.
    assert bare_request.word_timestamps is False
    assert bare_request.timestamp_granularities is None


def test_transcription_request_word_timestamps_accepted():
    """Explicitly supplied timestamp options are stored on the request as given."""
    from mlx_audio import server

    timestamp_options = {
        "word_timestamps": True,
        "timestamp_granularities": "word",
    }
    populated_request = server.TranscriptionRequest(
        model="test-model", **timestamp_options
    )

    assert populated_request.word_timestamps is True
    assert populated_request.timestamp_granularities == "word"


def test_stt_word_timestamps_passed_to_generate(client, mock_model_provider):
    """The word_timestamps=true form field arrives at stt_model.generate() as a kwarg.

    The stub generate() below does not name word_timestamps in its signature,
    so the value can only get through via the STTExecutionAdapter allowlist
    (_STT_EXTRA_KWARGS).
    """
    seen_kwargs: dict = {}

    # Parameter names are kept as (path, **kwargs) on purpose: the adapter
    # inspects the signature of generate() when filtering kwargs.
    def recording_generate(path, **kwargs):
        seen_kwargs.update(kwargs)
        return {"text": "hello", "segments": [], "language": "en"}

    stt_stub = MagicMock()
    stt_stub.generate = recording_generate
    mock_model_provider.load_model = MagicMock(return_value=stt_stub)

    upload = {"file": ("test.mp3", _make_transcription_audio_buffer(), "audio/mp3")}
    form_fields = {
        "model": "test_stt_model",
        "response_format": "verbose_json",
        "word_timestamps": "true",
    }
    response = client.post(
        "/v1/audio/transcriptions",
        files=upload,
        data=form_fields,
    )

    assert response.status_code == 200
    # The form string "true" must have been coerced to a real bool on the way in.
    assert seen_kwargs.get("word_timestamps") is True


def test_stt_word_timestamps_verbose_json_words_passthrough(client, mock_model_provider):
    """With word_timestamps=True, verbose_json keeps the model's per-segment words[]."""
    word_entries = [
        {"word": "Hello", "start": 0.0, "end": 0.5, "probability": 0.99},
        {"word": "world.", "start": 0.5, "end": 1.0, "probability": 0.98},
    ]
    only_segment = {
        "id": 0,
        "text": "Hello world.",
        "start": 0.0,
        "end": 1.0,
        "words": word_entries,
    }
    model_output = {
        "text": "Hello world.",
        "language": "en",
        "segments": [only_segment],
    }

    # The stub model returns the canned payload for any generate() call.
    stt_stub = MagicMock()
    stt_stub.generate = MagicMock(return_value=model_output)
    mock_model_provider.load_model = MagicMock(return_value=stt_stub)

    upload = {"file": ("test.mp3", _make_transcription_audio_buffer(), "audio/mp3")}
    form_fields = {
        "model": "test_stt_model",
        "response_format": "verbose_json",
        "word_timestamps": "true",
    }
    response = client.post(
        "/v1/audio/transcriptions",
        files=upload,
        data=form_fields,
    )

    assert response.status_code == 200

    body = response.json()
    assert body["text"] == "Hello world."

    returned_segments = body.get("segments", [])
    assert len(returned_segments) == 1

    # Both words must come back, in order, with their text intact.
    returned_words = returned_segments[0].get("words", [])
    assert len(returned_words) == 2
    first_word, second_word = returned_words
    assert first_word["word"] == "Hello"
    assert second_word["word"] == "world."
Loading