Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: twine upload dist/*
run: twine upload dist/*
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ run:
--queue-size 100

install:
pip install -e .
pip install -e .
5 changes: 3 additions & 2 deletions app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os

from .version import __version__

# Suppress transformers warnings
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

__all__ = ["__version__"]
__all__ = ["__version__"]
44 changes: 24 additions & 20 deletions app/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ async def chat_completions(
content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR
)


@router.post("/v1/embeddings", response_model=None)
async def embeddings(
request: EmbeddingRequest, raw_request: Request
Expand Down Expand Up @@ -855,7 +856,6 @@ async def process_multimodal_request(
return JSONResponse(content=final_response.model_dump(exclude_none=True))



async def process_text_request(
handler: MLXLMHandler | MLXVLMHandler,
request: ChatCompletionRequest,
Expand Down Expand Up @@ -990,10 +990,12 @@ def format_final_response(
request_id=request_id,
)


# =============================================================================
# Responses API Handlers
# =============================================================================


def _normalize_responses_item(item: Any) -> dict[str, Any]:
"""Normalize TypedDict/BaseModel response item to a plain dictionary."""
if isinstance(item, dict):
Expand All @@ -1014,7 +1016,9 @@ def _serialize_responses_tool_output(output: Any) -> str:
text_parts: list[str] = []
for output_item in output:
normalized = _normalize_responses_item(output_item)
if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get("text"):
if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get(
"text"
):
text_parts.append(str(normalized["text"]))
if text_parts:
return "\n".join(text_parts)
Expand Down Expand Up @@ -1104,9 +1108,7 @@ def _convert_responses_tool_choice(tool_choice: Any) -> Any:
return "auto"


def convert_responses_request_to_chat_request(
request: ResponsesRequest
) -> ChatCompletionRequest:
def convert_responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest:
"""Convert a Responses request into a ChatCompletionRequest with full turn history."""
chat_messages: list[Message] = []
pending_tool_calls: list[ChatCompletionMessageToolCall] = []
Expand Down Expand Up @@ -1188,7 +1190,9 @@ def flush_pending_tool_calls() -> None:
if item_type in {"input_text", "text"}:
text = item.get("text")
if text:
pending_user_parts.append(ChatCompletionContentPartText(type="text", text=str(text)))
pending_user_parts.append(
ChatCompletionContentPartText(type="text", text=str(text))
)
elif item_type in {"input_image", "image_url"} and item.get("image_url"):
pending_user_parts.append(
ChatCompletionContentPartImage(
Expand Down Expand Up @@ -1253,10 +1257,9 @@ def flush_pending_tool_calls() -> None:

return ChatCompletionRequest(**chat_request_payload)


def format_final_responses_response(
response: str | dict[str, Any],
request: ResponsesRequest,
usage: UsageInfo | None = None
response: str | dict[str, Any], request: ResponsesRequest, usage: UsageInfo | None = None
) -> ResponsesResponse:
"""Format the final non-streaming response."""
response_payload: dict[str, Any]
Expand Down Expand Up @@ -1391,6 +1394,7 @@ def refine_responses_request(
request.model = Config.TEXT_MODEL
return request


async def handle_responses_stream_response(
generator: AsyncGenerator[Any, None],
request: ResponsesRequest,
Expand Down Expand Up @@ -1438,9 +1442,7 @@ def _create_base_response(
text_val = None
if request.text:
text_val = (
request.text.model_dump()
if hasattr(request.text, "model_dump")
else request.text
request.text.model_dump() if hasattr(request.text, "model_dump") else request.text
)
return {
"id": resp_id,
Expand Down Expand Up @@ -1666,7 +1668,12 @@ def _open_message_item() -> str:
{
"id": msg_item_id,
"content": [
{"annotations": [], "text": full_text, "type": "output_text", "logprobs": None}
{
"annotations": [],
"text": full_text,
"type": "output_text",
"logprobs": None,
}
],
"role": "assistant",
"status": "completed",
Expand Down Expand Up @@ -1707,9 +1714,9 @@ def _open_message_item() -> str:
f"data: {json.dumps({'response': final_response_obj, 'sequence_number': _next_seq(), 'type': 'response.completed'})}\n\n"
)


async def process_text_responses_request(
handler: MLXLMHandler,
request: ResponsesRequest
handler: MLXLMHandler, request: ResponsesRequest
) -> ResponsesResponse | StreamingResponse | JSONResponse:
"""Handle text-only Responses API requests."""
refined_request = refine_responses_request(request, handler)
Expand Down Expand Up @@ -1764,13 +1771,10 @@ async def process_multimodal_responses_request(
result = await handler.generate_multimodal_response(chat_request)
response_data = result.get("response")
usage = result.get("usage")
final_response = format_final_responses_response(
response_data,
refined_request,
usage
)
final_response = format_final_responses_response(response_data, refined_request, usage)
return JSONResponse(content=final_response.model_dump(exclude_none=True))


@router.post("/v1/responses", response_model=None)
async def responses_endpoint(
request: ResponsesRequest, raw_request: Request
Expand Down
54 changes: 27 additions & 27 deletions app/core/audio_processor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os
import gc
import asyncio
from typing import List
import gc
import os

from .base_processor import BaseProcessor


class AudioProcessor(BaseProcessor):
"""Audio processor for handling audio files with caching and validation."""

def __init__(self, max_workers: int = 4, cache_size: int = 1000):
super().__init__(max_workers, cache_size)
# Supported audio formats
self._supported_formats = {'.mp3', '.wav'}
self._supported_formats = {".mp3", ".wav"}

def _get_media_format(self, media_url: str, data: bytes = None) -> str:
"""Determine audio format from URL or data."""
Expand All @@ -20,50 +20,50 @@ def _get_media_format(self, media_url: str, data: bytes = None) -> str:
mime_type = media_url.split(";")[0].split(":")[1]
if "mp3" in mime_type or "mpeg" in mime_type:
return "mp3"
elif "wav" in mime_type:
if "wav" in mime_type:
return "wav"
elif "m4a" in mime_type or "mp4" in mime_type:
if "m4a" in mime_type or "mp4" in mime_type:
return "m4a"
elif "ogg" in mime_type:
if "ogg" in mime_type:
return "ogg"
elif "flac" in mime_type:
if "flac" in mime_type:
return "flac"
elif "aac" in mime_type:
if "aac" in mime_type:
return "aac"
else:
# Extract format from file extension
ext = os.path.splitext(media_url.lower())[1]
if ext in self._supported_formats:
return ext[1:] # Remove the dot

# Default to mp3 if format cannot be determined
return "mp3"

def _validate_media_data(self, data: bytes) -> bool:
"""Basic validation of audio data."""
if len(data) < 100: # Too small to be a valid audio file
return False

# Check for common audio file signatures
audio_signatures = [
b'ID3', # MP3 with ID3 tag
b'\xff\xfb', # MP3 frame header
b'\xff\xf3', # MP3 frame header
b'\xff\xf2', # MP3 frame header
b'RIFF', # WAV/AVI
b'OggS', # OGG
b'fLaC', # FLAC
b'\x00\x00\x00\x20ftypM4A', # M4A
b"ID3", # MP3 with ID3 tag
b"\xff\xfb", # MP3 frame header
b"\xff\xf3", # MP3 frame header
b"\xff\xf2", # MP3 frame header
b"RIFF", # WAV/AVI
b"OggS", # OGG
b"fLaC", # FLAC
b"\x00\x00\x00\x20ftypM4A", # M4A
]

for sig in audio_signatures:
if data.startswith(sig):
return True

# Check for WAV format (RIFF header might be at different position)
if b'WAVE' in data[:50]:
if b"WAVE" in data[:50]:
return True

return True # Allow unknown formats to pass through

def _get_timeout(self) -> int:
Expand All @@ -76,7 +76,7 @@ def _get_max_file_size(self) -> int:

def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
"""Process audio data and save to cached path."""
with open(cached_path, 'wb') as f:
with open(cached_path, "wb") as f:
f.write(data)
self._cleanup_old_files()
return cached_path
Expand All @@ -89,10 +89,10 @@ async def process_audio_url(self, audio_url: str) -> str:
"""Process a single audio URL and return path to cached file."""
return await self._process_single_media(audio_url)

async def process_audio_urls(self, audio_urls: List[str]) -> List[str]:
async def process_audio_urls(self, audio_urls: list[str]) -> list[str]:
"""Process multiple audio URLs and return paths to cached files."""
tasks = [self.process_audio_url(url) for url in audio_urls]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Force garbage collection after batch processing
gc.collect()
return results
return results
Loading
Loading