diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index a6a7bc65..cd40bb4c 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -29,4 +29,4 @@ jobs:
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
- run: twine upload dist/*
\ No newline at end of file
+ run: twine upload dist/*
diff --git a/Makefile b/Makefile
index 0398c325..519efd24 100644
--- a/Makefile
+++ b/Makefile
@@ -7,4 +7,4 @@ run:
--queue-size 100
install:
- pip install -e .
\ No newline at end of file
+ pip install -e .
diff --git a/app/__init__.py b/app/__init__.py
index 0b90d88e..6e5d4622 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,7 +1,8 @@
import os
+
from .version import __version__
# Suppress transformers warnings
-os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
+os.environ["TRANSFORMERS_VERBOSITY"] = "error"
-__all__ = ["__version__"]
\ No newline at end of file
+__all__ = ["__version__"]
diff --git a/app/api/endpoints.py b/app/api/endpoints.py
index 91d0d4c7..fc247a84 100644
--- a/app/api/endpoints.py
+++ b/app/api/endpoints.py
@@ -419,6 +419,7 @@ async def chat_completions(
content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR
)
+
@router.post("/v1/embeddings", response_model=None)
async def embeddings(
request: EmbeddingRequest, raw_request: Request
@@ -855,7 +856,6 @@ async def process_multimodal_request(
return JSONResponse(content=final_response.model_dump(exclude_none=True))
-
async def process_text_request(
handler: MLXLMHandler | MLXVLMHandler,
request: ChatCompletionRequest,
@@ -990,10 +990,12 @@ def format_final_response(
request_id=request_id,
)
+
# =============================================================================
# Responses API Handlers
# =============================================================================
+
def _normalize_responses_item(item: Any) -> dict[str, Any]:
"""Normalize TypedDict/BaseModel response item to a plain dictionary."""
if isinstance(item, dict):
@@ -1014,7 +1016,9 @@ def _serialize_responses_tool_output(output: Any) -> str:
text_parts: list[str] = []
for output_item in output:
normalized = _normalize_responses_item(output_item)
- if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get("text"):
+ if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get(
+ "text"
+ ):
text_parts.append(str(normalized["text"]))
if text_parts:
return "\n".join(text_parts)
@@ -1104,9 +1108,7 @@ def _convert_responses_tool_choice(tool_choice: Any) -> Any:
return "auto"
-def convert_responses_request_to_chat_request(
- request: ResponsesRequest
-) -> ChatCompletionRequest:
+def convert_responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest:
"""Convert a Responses request into a ChatCompletionRequest with full turn history."""
chat_messages: list[Message] = []
pending_tool_calls: list[ChatCompletionMessageToolCall] = []
@@ -1188,7 +1190,9 @@ def flush_pending_tool_calls() -> None:
if item_type in {"input_text", "text"}:
text = item.get("text")
if text:
- pending_user_parts.append(ChatCompletionContentPartText(type="text", text=str(text)))
+ pending_user_parts.append(
+ ChatCompletionContentPartText(type="text", text=str(text))
+ )
elif item_type in {"input_image", "image_url"} and item.get("image_url"):
pending_user_parts.append(
ChatCompletionContentPartImage(
@@ -1253,10 +1257,9 @@ def flush_pending_tool_calls() -> None:
return ChatCompletionRequest(**chat_request_payload)
+
def format_final_responses_response(
- response: str | dict[str, Any],
- request: ResponsesRequest,
- usage: UsageInfo | None = None
+ response: str | dict[str, Any], request: ResponsesRequest, usage: UsageInfo | None = None
) -> ResponsesResponse:
"""Format the final non-streaming response."""
response_payload: dict[str, Any]
@@ -1391,6 +1394,7 @@ def refine_responses_request(
request.model = Config.TEXT_MODEL
return request
+
async def handle_responses_stream_response(
generator: AsyncGenerator[Any, None],
request: ResponsesRequest,
@@ -1438,9 +1442,7 @@ def _create_base_response(
text_val = None
if request.text:
text_val = (
- request.text.model_dump()
- if hasattr(request.text, "model_dump")
- else request.text
+ request.text.model_dump() if hasattr(request.text, "model_dump") else request.text
)
return {
"id": resp_id,
@@ -1666,7 +1668,12 @@ def _open_message_item() -> str:
{
"id": msg_item_id,
"content": [
- {"annotations": [], "text": full_text, "type": "output_text", "logprobs": None}
+ {
+ "annotations": [],
+ "text": full_text,
+ "type": "output_text",
+ "logprobs": None,
+ }
],
"role": "assistant",
"status": "completed",
@@ -1707,9 +1714,9 @@ def _open_message_item() -> str:
f"data: {json.dumps({'response': final_response_obj, 'sequence_number': _next_seq(), 'type': 'response.completed'})}\n\n"
)
+
async def process_text_responses_request(
- handler: MLXLMHandler,
- request: ResponsesRequest
+ handler: MLXLMHandler, request: ResponsesRequest
) -> ResponsesResponse | StreamingResponse | JSONResponse:
"""Handle text-only Responses API requests."""
refined_request = refine_responses_request(request, handler)
@@ -1764,13 +1771,10 @@ async def process_multimodal_responses_request(
result = await handler.generate_multimodal_response(chat_request)
response_data = result.get("response")
usage = result.get("usage")
- final_response = format_final_responses_response(
- response_data,
- refined_request,
- usage
- )
+ final_response = format_final_responses_response(response_data, refined_request, usage)
return JSONResponse(content=final_response.model_dump(exclude_none=True))
+
@router.post("/v1/responses", response_model=None)
async def responses_endpoint(
request: ResponsesRequest, raw_request: Request
diff --git a/app/core/audio_processor.py b/app/core/audio_processor.py
index 814fb292..f871987d 100644
--- a/app/core/audio_processor.py
+++ b/app/core/audio_processor.py
@@ -1,17 +1,17 @@
-import os
-import gc
import asyncio
-from typing import List
+import gc
+import os
+
from .base_processor import BaseProcessor
class AudioProcessor(BaseProcessor):
"""Audio processor for handling audio files with caching and validation."""
-
+
def __init__(self, max_workers: int = 4, cache_size: int = 1000):
super().__init__(max_workers, cache_size)
# Supported audio formats
- self._supported_formats = {'.mp3', '.wav'}
+ self._supported_formats = {".mp3", ".wav"}
def _get_media_format(self, media_url: str, data: bytes = None) -> str:
"""Determine audio format from URL or data."""
@@ -20,22 +20,22 @@ def _get_media_format(self, media_url: str, data: bytes = None) -> str:
mime_type = media_url.split(";")[0].split(":")[1]
if "mp3" in mime_type or "mpeg" in mime_type:
return "mp3"
- elif "wav" in mime_type:
+ if "wav" in mime_type:
return "wav"
- elif "m4a" in mime_type or "mp4" in mime_type:
+ if "m4a" in mime_type or "mp4" in mime_type:
return "m4a"
- elif "ogg" in mime_type:
+ if "ogg" in mime_type:
return "ogg"
- elif "flac" in mime_type:
+ if "flac" in mime_type:
return "flac"
- elif "aac" in mime_type:
+ if "aac" in mime_type:
return "aac"
else:
# Extract format from file extension
ext = os.path.splitext(media_url.lower())[1]
if ext in self._supported_formats:
return ext[1:] # Remove the dot
-
+
# Default to mp3 if format cannot be determined
return "mp3"
@@ -43,27 +43,27 @@ def _validate_media_data(self, data: bytes) -> bool:
"""Basic validation of audio data."""
if len(data) < 100: # Too small to be a valid audio file
return False
-
+
# Check for common audio file signatures
audio_signatures = [
- b'ID3', # MP3 with ID3 tag
- b'\xff\xfb', # MP3 frame header
- b'\xff\xf3', # MP3 frame header
- b'\xff\xf2', # MP3 frame header
- b'RIFF', # WAV/AVI
- b'OggS', # OGG
- b'fLaC', # FLAC
- b'\x00\x00\x00\x20ftypM4A', # M4A
+ b"ID3", # MP3 with ID3 tag
+ b"\xff\xfb", # MP3 frame header
+ b"\xff\xf3", # MP3 frame header
+ b"\xff\xf2", # MP3 frame header
+ b"RIFF", # WAV/AVI
+ b"OggS", # OGG
+ b"fLaC", # FLAC
+ b"\x00\x00\x00\x20ftypM4A", # M4A
]
-
+
for sig in audio_signatures:
if data.startswith(sig):
return True
-
+
# Check for WAV format (RIFF header might be at different position)
- if b'WAVE' in data[:50]:
+ if b"WAVE" in data[:50]:
return True
-
+
return True # Allow unknown formats to pass through
def _get_timeout(self) -> int:
@@ -76,7 +76,7 @@ def _get_max_file_size(self) -> int:
def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
"""Process audio data and save to cached path."""
- with open(cached_path, 'wb') as f:
+ with open(cached_path, "wb") as f:
f.write(data)
self._cleanup_old_files()
return cached_path
@@ -89,10 +89,10 @@ async def process_audio_url(self, audio_url: str) -> str:
"""Process a single audio URL and return path to cached file."""
return await self._process_single_media(audio_url)
- async def process_audio_urls(self, audio_urls: List[str]) -> List[str]:
+ async def process_audio_urls(self, audio_urls: list[str]) -> list[str]:
"""Process multiple audio URLs and return paths to cached files."""
tasks = [self.process_audio_url(url) for url in audio_urls]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Force garbage collection after batch processing
gc.collect()
- return results
\ No newline at end of file
+ return results
diff --git a/app/core/base_processor.py b/app/core/base_processor.py
index 52833fcf..74b75c8c 100644
--- a/app/core/base_processor.py
+++ b/app/core/base_processor.py
@@ -1,30 +1,31 @@
+from abc import ABC, abstractmethod
import base64
+from concurrent.futures import ThreadPoolExecutor
+import gc
import hashlib
import os
import tempfile
-import aiohttp
import time
-import gc
+from typing import Any
+
+import aiohttp
from loguru import logger
-from typing import Dict, Optional, Any
-from concurrent.futures import ThreadPoolExecutor
-from abc import ABC, abstractmethod
class BaseProcessor(ABC):
"""Base class for media processors with common caching and session management."""
-
+
def __init__(self, max_workers: int = 4, cache_size: int = 1000):
# Use tempfile for macOS-efficient temporary file handling
self.temp_dir = tempfile.TemporaryDirectory()
- self._session: Optional[aiohttp.ClientSession] = None
+ self._session: aiohttp.ClientSession | None = None
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._cache_size = cache_size
self._last_cleanup = time.time()
self._cleanup_interval = 3600 # 1 hour
# Replace lru_cache with manual cache for better control
- self._hash_cache: Dict[str, str] = {}
- self._cache_access_times: Dict[str, float] = {}
+ self._hash_cache: dict[str, str] = {}
+ self._cache_access_times: dict[str, float] = {}
def _get_media_hash(self, media_url: str) -> str:
"""Get hash for media URL with manual caching that can be cleared."""
@@ -32,20 +33,20 @@ def _get_media_hash(self, media_url: str) -> str:
if media_url in self._hash_cache:
self._cache_access_times[media_url] = time.time()
return self._hash_cache[media_url]
-
+
# Generate hash
if media_url.startswith("data:"):
_, encoded = media_url.split(",", 1)
data = base64.b64decode(encoded)
else:
- data = media_url.encode('utf-8')
-
+ data = media_url.encode("utf-8")
+
hash_value = hashlib.md5(data).hexdigest()
-
+
# Add to cache with size management
if len(self._hash_cache) >= self._cache_size:
self._evict_oldest_cache_entries()
-
+
self._hash_cache[media_url] = hash_value
self._cache_access_times[media_url] = time.time()
return hash_value
@@ -54,53 +55,47 @@ def _evict_oldest_cache_entries(self):
"""Remove oldest 20% of cache entries to make room."""
if not self._cache_access_times:
return
-
+
# Sort by access time and remove oldest 20%
sorted_items = sorted(self._cache_access_times.items(), key=lambda x: x[1])
to_remove = len(sorted_items) // 5 # Remove 20%
-
+
for url, _ in sorted_items[:to_remove]:
self._hash_cache.pop(url, None)
self._cache_access_times.pop(url, None)
-
+
# Force garbage collection after cache eviction
gc.collect()
@abstractmethod
def _get_media_format(self, media_url: str, data: bytes = None) -> str:
"""Determine media format from URL or data. Must be implemented by subclasses."""
- pass
@abstractmethod
def _validate_media_data(self, data: bytes) -> bool:
"""Validate media data. Must be implemented by subclasses."""
- pass
@abstractmethod
def _get_timeout(self) -> int:
"""Get timeout for HTTP requests. Must be implemented by subclasses."""
- pass
@abstractmethod
def _get_max_file_size(self) -> int:
"""Get maximum file size in bytes. Must be implemented by subclasses."""
- pass
@abstractmethod
- def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> Dict[str, Any]:
+ def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> dict[str, Any]:
"""Process media data and save to cached path. Must be implemented by subclasses."""
- pass
@abstractmethod
def _get_media_type_name(self) -> str:
"""Get media type name for logging. Must be implemented by subclasses."""
- pass
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self._get_timeout()),
- headers={"User-Agent": "mlx-server-OAI-compat/1.0"}
+ headers={"User-Agent": "mlx-server-OAI-compat/1.0"},
)
return self._session
@@ -118,7 +113,7 @@ def _cleanup_old_files(self):
self._evict_oldest_cache_entries()
gc.collect() # Force garbage collection after cleanup
except Exception as e:
- logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {str(e)}")
+ logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {e!s}")
async def _process_single_media(self, media_url: str, **kwargs) -> str:
try:
@@ -132,39 +127,40 @@ async def _process_single_media(self, media_url: str, **kwargs) -> str:
if os.path.exists(media_url):
# Copy local file to cache
- with open(media_url, 'rb') as f:
+ with open(media_url, "rb") as f:
data = f.read()
-
+
if not self._validate_media_data(data):
raise ValueError(f"Invalid {self._get_media_type_name()} file format")
-
+
return self._process_media_data(data, cached_path, **kwargs)
- elif media_url.startswith("data:"):
+ if media_url.startswith("data:"):
_, encoded = media_url.split(",", 1)
estimated_size = len(encoded) * 3 / 4
if estimated_size > self._get_max_file_size():
- raise ValueError(f"Base64-encoded {self._get_media_type_name()} exceeds size limit")
+ raise ValueError(
+ f"Base64-encoded {self._get_media_type_name()} exceeds size limit"
+ )
data = base64.b64decode(encoded)
-
+
if not self._validate_media_data(data):
raise ValueError(f"Invalid {self._get_media_type_name()} file format")
-
+
+ return self._process_media_data(data, cached_path, **kwargs)
+ session = await self._get_session()
+ async with session.get(media_url) as response:
+ response.raise_for_status()
+ data = await response.read()
+
+ if not self._validate_media_data(data):
+ raise ValueError(f"Invalid {self._get_media_type_name()} file format")
+
return self._process_media_data(data, cached_path, **kwargs)
- else:
- session = await self._get_session()
- async with session.get(media_url) as response:
- response.raise_for_status()
- data = await response.read()
-
- if not self._validate_media_data(data):
- raise ValueError(f"Invalid {self._get_media_type_name()} file format")
-
- return self._process_media_data(data, cached_path, **kwargs)
except Exception as e:
- logger.error(f"Failed to process {self._get_media_type_name()}: {str(e)}")
- raise ValueError(f"Failed to process {self._get_media_type_name()}: {str(e)}")
+ logger.error(f"Failed to process {self._get_media_type_name()}: {e!s}")
+ raise ValueError(f"Failed to process {self._get_media_type_name()}: {e!s}")
finally:
gc.collect()
@@ -175,25 +171,25 @@ def clear_cache(self):
gc.collect()
async def cleanup(self):
- if hasattr(self, '_cleaned') and self._cleaned:
+ if hasattr(self, "_cleaned") and self._cleaned:
return
self._cleaned = True
try:
# Clear caches before cleanup
self.clear_cache()
-
+
if self._session and not self._session.closed:
await self._session.close()
except Exception as e:
- logger.warning(f"Exception closing aiohttp session: {str(e)}")
+ logger.warning(f"Exception closing aiohttp session: {e!s}")
try:
self.executor.shutdown(wait=True)
except Exception as e:
- logger.warning(f"Exception shutting down executor: {str(e)}")
+ logger.warning(f"Exception shutting down executor: {e!s}")
try:
self.temp_dir.cleanup()
except Exception as e:
- logger.warning(f"Exception cleaning up temp directory: {str(e)}")
+ logger.warning(f"Exception cleaning up temp directory: {e!s}")
async def __aenter__(self):
return self
@@ -204,4 +200,4 @@ async def __aexit__(self, exc_type, exc, tb):
def __del__(self):
# Async cleanup cannot be reliably performed in __del__
# Please use 'async with Processor()' or call 'await cleanup()' explicitly.
- pass
\ No newline at end of file
+ pass
diff --git a/app/core/handler_process.py b/app/core/handler_process.py
index b38dd3fa..1898d5da 100644
--- a/app/core/handler_process.py
+++ b/app/core/handler_process.py
@@ -439,10 +439,7 @@ async def start(self, queue_config: dict[str, Any]) -> None:
name=f"handler-{self.model_id}",
)
self._process.start()
- logger.info(
- f"Spawned handler process for '{self.model_id}' "
- f"(pid={self._process.pid})"
- )
+ logger.info(f"Spawned handler process for '{self.model_id}' (pid={self._process.pid})")
# Wait for the ready signal.
ready_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
@@ -503,14 +500,10 @@ def _response_reader(self) -> None:
pending = self._pending.get(req_id)
if pending and self._loop:
try:
- future = asyncio.run_coroutine_threadsafe(
- pending.put(response), self._loop
- )
+ future = asyncio.run_coroutine_threadsafe(pending.put(response), self._loop)
future.result(timeout=60)
except concurrent.futures.TimeoutError:
- logger.warning(
- f"Timeout delivering stream chunk for {req_id}"
- )
+ logger.warning(f"Timeout delivering stream chunk for {req_id}")
except Exception:
if self._running:
logger.debug(
diff --git a/app/core/image_processor.py b/app/core/image_processor.py
index c38ccc65..a941a675 100644
--- a/app/core/image_processor.py
+++ b/app/core/image_processor.py
@@ -1,15 +1,16 @@
-import gc
import asyncio
-from PIL import Image
-from loguru import logger
+import gc
from io import BytesIO
-from typing import List
+
+from loguru import logger
+from PIL import Image
+
from .base_processor import BaseProcessor
class ImageProcessor(BaseProcessor):
"""Image processor for handling image files with caching, validation, and processing."""
-
+
def __init__(self, max_workers: int = 4, cache_size: int = 1000):
super().__init__(max_workers, cache_size)
Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels
@@ -23,27 +24,27 @@ def _validate_media_data(self, data: bytes) -> bool:
"""Basic validation of image data."""
if len(data) < 100: # Too small to be a valid image file
return False
-
+
# Check for common image file signatures
image_signatures = [
- b'\xff\xd8\xff', # JPEG
- b'\x89PNG\r\n\x1a\n', # PNG
- b'GIF87a', # GIF87a
- b'GIF89a', # GIF89a
- b'BM', # BMP
- b'II*\x00', # TIFF (little endian)
- b'MM\x00*', # TIFF (big endian)
- b'RIFF', # WebP (part of RIFF)
+ b"\xff\xd8\xff", # JPEG
+ b"\x89PNG\r\n\x1a\n", # PNG
+ b"GIF87a", # GIF87a
+ b"GIF89a", # GIF89a
+ b"BM", # BMP
+ b"II*\x00", # TIFF (little endian)
+ b"MM\x00*", # TIFF (big endian)
+ b"RIFF", # WebP (part of RIFF)
]
-
+
for sig in image_signatures:
if data.startswith(sig):
return True
-
+
# Additional check for WebP
- if data.startswith(b'RIFF') and b'WEBP' in data[:20]:
+ if data.startswith(b"RIFF") and b"WEBP" in data[:20]:
return True
-
+
return False
def _get_timeout(self) -> int:
@@ -58,7 +59,9 @@ def _get_media_type_name(self) -> str:
"""Get media type name for logging."""
return "image"
- def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image:
+ def _resize_image_keep_aspect_ratio(
+ self, image: Image.Image, max_size: int = 448
+ ) -> Image.Image:
width, height = image.size
if width <= max_size and height <= max_size:
return image
@@ -75,15 +78,15 @@ def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 44
return image
def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image:
- if image.mode in ('RGBA', 'LA'):
- background = Image.new('RGB', image.size, (255, 255, 255))
- if image.mode == 'RGBA':
+ if image.mode in ("RGBA", "LA"):
+ background = Image.new("RGB", image.size, (255, 255, 255))
+ if image.mode == "RGBA":
background.paste(image, mask=image.split()[3])
else:
background.paste(image, mask=image.split()[1])
return background
- elif image.mode != 'RGB':
- return image.convert('RGB')
+ if image.mode != "RGB":
+ return image.convert("RGB")
return image
def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
@@ -91,12 +94,12 @@ def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
image = None
resize = kwargs.get("resize", True)
try:
- with Image.open(BytesIO(data), mode='r') as image:
+ with Image.open(BytesIO(data), mode="r") as image:
if resize:
image = self._resize_image_keep_aspect_ratio(image)
image = self._prepare_image_for_saving(image)
- image.save(cached_path, 'PNG', quality=100, optimize=True)
-
+ image.save(cached_path, "PNG", quality=100, optimize=True)
+
self._cleanup_old_files()
return cached_path
finally:
@@ -111,10 +114,10 @@ async def process_image_url(self, image_url: str, resize: bool = True) -> str:
"""Process a single image URL and return path to cached file."""
return await self._process_single_media(image_url, resize=resize)
- async def process_image_urls(self, image_urls: List[str], resize: bool = True) -> List[str]:
+ async def process_image_urls(self, image_urls: list[str], resize: bool = True) -> list[str]:
"""Process multiple image URLs and return paths to cached files."""
tasks = [self.process_image_url(url, resize=resize) for url in image_urls]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Force garbage collection after batch processing
gc.collect()
- return results
\ No newline at end of file
+ return results
diff --git a/app/core/inference_worker.py b/app/core/inference_worker.py
index c76e54cd..31563766 100644
--- a/app/core/inference_worker.py
+++ b/app/core/inference_worker.py
@@ -3,11 +3,11 @@
from __future__ import annotations
import asyncio
+from collections.abc import AsyncGenerator, Callable, Generator
import queue
import threading
-from collections.abc import AsyncGenerator, Generator
from threading import Thread
-from typing import Any, Callable
+from typing import Any
from loguru import logger
@@ -85,9 +85,7 @@ def _record(self, success: bool) -> None:
else:
self._failed += 1
- async def submit(
- self, func: Callable[..., Any], *args: Any, **kwargs: Any
- ) -> Any:
+ async def submit(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Run func on the worker thread; await its result. Raises QueueFull, TimeoutError, or func's exception."""
loop = asyncio.get_running_loop()
future: asyncio.Future[Any] = loop.create_future()
@@ -107,7 +105,7 @@ def _work() -> None:
raise asyncio.QueueFull("Inference queue is full")
try:
return await asyncio.wait_for(future, timeout=self._timeout)
- except asyncio.TimeoutError:
+ except TimeoutError:
raise TimeoutError(f"Inference timed out after {self._timeout}s")
def submit_stream(
diff --git a/app/core/model_registry.py b/app/core/model_registry.py
index 590f3726..005e18ec 100644
--- a/app/core/model_registry.py
+++ b/app/core/model_registry.py
@@ -82,8 +82,7 @@ async def register_model(
self._metadata[model_id] = metadata
logger.info(
- f"Registered model: {model_id} (type={model_type}, "
- f"context_length={context_length})"
+ f"Registered model: {model_id} (type={model_type}, context_length={context_length})"
)
def get_handler(self, model_id: str) -> Any:
@@ -107,8 +106,7 @@ def get_handler(self, model_id: str) -> Any:
if model_id not in self._handlers:
available = ", ".join(sorted(self._handlers.keys())) or "(none)"
raise KeyError(
- f"Model '{model_id}' not found in registry. "
- f"Available models: {available}"
+ f"Model '{model_id}' not found in registry. Available models: {available}"
)
return self._handlers[model_id]
@@ -203,9 +201,7 @@ async def cleanup_all(self) -> None:
logger.info("All models unregistered and cleaned up")
@staticmethod
- async def _cleanup_single_handler(
- model_id: str, handler: Any
- ) -> None:
+ async def _cleanup_single_handler(model_id: str, handler: Any) -> None:
"""Clean up a single handler, logging success or failure.
Parameters
diff --git a/app/core/video_processor.py b/app/core/video_processor.py
index 92d7141d..0123ea41 100644
--- a/app/core/video_processor.py
+++ b/app/core/video_processor.py
@@ -1,18 +1,19 @@
-import os
-import gc
import asyncio
+import gc
+import os
+
from loguru import logger
-from typing import List
+
from .base_processor import BaseProcessor
class VideoProcessor(BaseProcessor):
"""Video processor for handling video files with caching, validation, and processing."""
-
+
def __init__(self, max_workers: int = 4, cache_size: int = 1000):
super().__init__(max_workers, cache_size)
# Supported video formats
- self._supported_formats = {'.mp4', '.avi', '.mov'}
+ self._supported_formats = {".mp4", ".avi", ".mov"}
def _get_media_format(self, media_url: str, data: bytes = None) -> str:
"""Determine video format from URL or data."""
@@ -21,16 +22,16 @@ def _get_media_format(self, media_url: str, data: bytes = None) -> str:
mime_type = media_url.split(";")[0].split(":")[1]
if "mp4" in mime_type:
return "mp4"
- elif "quicktime" in mime_type or "mov" in mime_type:
+ if "quicktime" in mime_type or "mov" in mime_type:
return "mov"
- elif "x-msvideo" in mime_type or "avi" in mime_type:
+ if "x-msvideo" in mime_type or "avi" in mime_type:
return "avi"
else:
# Extract format from file extension
ext = os.path.splitext(media_url.lower())[1]
if ext in self._supported_formats:
return ext[1:] # Remove the dot
-
+
# Default to mp4 if format cannot be determined
return "mp4"
@@ -38,50 +39,45 @@ def _validate_media_data(self, data: bytes) -> bool:
"""Basic validation of video data."""
if len(data) < 100: # Too small to be a valid video file
return False
-
+
# Check for common video file signatures
video_signatures = [
# MP4/M4V/MOV (ISO Base Media File Format)
- (b'\x00\x00\x00\x14ftypisom', 0), # MP4
- (b'\x00\x00\x00\x18ftyp', 0), # MP4/MOV
- (b'\x00\x00\x00\x1cftyp', 0), # MP4/MOV
- (b'\x00\x00\x00\x20ftyp', 0), # MP4/MOV
- (b'ftyp', 4), # MP4/MOV (ftyp at offset 4)
-
+ (b"\x00\x00\x00\x14ftypisom", 0), # MP4
+ (b"\x00\x00\x00\x18ftyp", 0), # MP4/MOV
+ (b"\x00\x00\x00\x1cftyp", 0), # MP4/MOV
+ (b"\x00\x00\x00\x20ftyp", 0), # MP4/MOV
+ (b"ftyp", 4), # MP4/MOV (ftyp at offset 4)
# AVI
- (b'RIFF', 0), # AVI (also check for 'AVI ' at offset 8)
-
+ (b"RIFF", 0), # AVI (also check for 'AVI ' at offset 8)
# WebM/MKV (Matroska)
- (b'\x1a\x45\xdf\xa3', 0), # Matroska/WebM
-
+ (b"\x1a\x45\xdf\xa3", 0), # Matroska/WebM
# FLV
- (b'FLV\x01', 0), # Flash Video
-
+ (b"FLV\x01", 0), # Flash Video
# MPEG
- (b'\x00\x00\x01\xba', 0), # MPEG PS
- (b'\x00\x00\x01\xb3', 0), # MPEG PS
-
+ (b"\x00\x00\x01\xba", 0), # MPEG PS
+ (b"\x00\x00\x01\xb3", 0), # MPEG PS
# QuickTime
- (b'moov', 0), # QuickTime
- (b'mdat', 0), # QuickTime
+ (b"moov", 0), # QuickTime
+ (b"mdat", 0), # QuickTime
]
-
+
for sig, offset in video_signatures:
if len(data) > offset + len(sig):
- if data[offset:offset+len(sig)] == sig:
+ if data[offset : offset + len(sig)] == sig:
# Additional validation for AVI
- if sig == b'RIFF' and len(data) > 12:
- if data[8:12] == b'AVI ':
+ if sig == b"RIFF" and len(data) > 12:
+ if data[8:12] == b"AVI ":
return True
- elif sig == b'RIFF':
+ elif sig == b"RIFF":
continue # Not AVI, might be WAV
else:
return True
-
+
# Check for ftyp box anywhere in first 32 bytes (MP4/MOV)
- if b'ftyp' in data[:32]:
+ if b"ftyp" in data[:32]:
return True
-
+
# Allow unknown formats to pass through for flexibility
return True
@@ -96,14 +92,14 @@ def _get_max_file_size(self) -> int:
def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str:
"""Process video data and save to cached path."""
try:
- with open(cached_path, 'wb') as f:
+ with open(cached_path, "wb") as f:
f.write(data)
-
+
logger.info(f"Saved video to {cached_path} ({len(data)} bytes)")
self._cleanup_old_files()
return cached_path
except Exception as e:
- logger.error(f"Failed to save video data: {str(e)}")
+ logger.error(f"Failed to save video data: {e!s}")
raise
def _get_media_type_name(self) -> str:
@@ -113,27 +109,27 @@ def _get_media_type_name(self) -> str:
async def process_video_url(self, video_url: str) -> str:
"""
Process a single video URL and return path to cached file.
-
+
Supports:
- HTTP/HTTPS URLs (downloads video)
- Local file paths (copies to cache)
- Data URLs (base64 encoded videos)
-
+
Args:
video_url: URL, file path, or data URL of the video
-
+
Returns:
Path to the cached video file in temp directory
"""
return await self._process_single_media(video_url)
- async def process_video_urls(self, video_urls: List[str]) -> List[str]:
+ async def process_video_urls(self, video_urls: list[str]) -> list[str]:
"""
Process multiple video URLs and return paths to cached files.
-
+
Args:
video_urls: List of URLs, file paths, or data URLs of videos
-
+
Returns:
List of paths to cached video files
"""
diff --git a/app/handler/mflux.py b/app/handler/mflux.py
index c54b770b..ef5abbcb 100644
--- a/app/handler/mflux.py
+++ b/app/handler/mflux.py
@@ -1,14 +1,14 @@
import asyncio
import base64
+import gc
+from http import HTTPStatus
import io
+from io import BytesIO
import os
import tempfile
import time
+from typing import Any
import uuid
-import gc
-from io import BytesIO
-from http import HTTPStatus
-from typing import Any, Dict, List, Optional, Union
from fastapi import HTTPException, UploadFile
from loguru import logger
@@ -35,12 +35,18 @@ class MLXFluxHandler:
handler_type: str = "image"
- def __init__(self, model_path: str, max_concurrency: int = 1, quantize: Optional[int] = None,
- config_name: str = "flux-schnell", lora_paths: Optional[List[str]] = None,
- lora_scales: Optional[List[float]] = None):
+ def __init__(
+ self,
+ model_path: str,
+ max_concurrency: int = 1,
+ quantize: int | None = None,
+ config_name: str = "flux-schnell",
+ lora_paths: list[str] | None = None,
+ lora_scales: list[float] | None = None,
+ ):
"""
Initialize the handler with the specified model path.
-
+
Args:
model_path (str): Path to the model directory or model name for Flux.
max_concurrency (int): Maximum number of concurrent model inference tasks.
@@ -54,40 +60,44 @@ def __init__(self, model_path: str, max_concurrency: int = 1, quantize: Optional
self.config_name = config_name
self.lora_paths = lora_paths
self.lora_scales = lora_scales
-
+
self.model = ImageGenerationModel(
- model_path=model_path,
+ model_path=model_path,
quantize=quantize,
config_name=config_name,
lora_paths=lora_paths,
- lora_scales=lora_scales
+ lora_scales=lora_scales,
)
self.model_created = int(time.time()) # Store creation time when model is loaded
-
+
# Dedicated inference thread — keeps the event loop free during
# blocking MLX model computation.
self.inference_worker = InferenceWorker()
- logger.info(f"Initialized MLXFluxHandler with model path: {model_path}, config name: {config_name}")
+ logger.info(
+ f"Initialized MLXFluxHandler with model path: {model_path}, config name: {config_name}"
+ )
if lora_paths:
logger.info(f"Using LoRA adapters: {lora_paths} with scales: {lora_scales}")
- async def get_models(self) -> List[Dict[str, Any]]:
+ async def get_models(self) -> list[dict[str, Any]]:
"""
Get list of available models with their metadata.
"""
try:
- return [{
- "id": self.model_path,
- "object": "model",
- "created": self.model_created,
- "owned_by": "local"
- }]
+ return [
+ {
+ "id": self.model_path,
+ "object": "model",
+ "created": self.model_created,
+ "owned_by": "local",
+ }
+ ]
except Exception as e:
- logger.error(f"Error getting models: {str(e)}")
+ logger.error(f"Error getting models: {e!s}")
return []
- async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> None:
+ async def initialize(self, queue_config: dict[str, Any] | None = None) -> None:
"""Initialize the handler and start the inference worker.
Parameters
@@ -112,29 +122,26 @@ async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> Non
def _parse_image_size(self, size: ImageSize) -> tuple[int, int]:
"""
Parse image size string to width, height tuple.
-
+
Parameters
----------
size : ImageSize
Image size enum value (e.g., "1024x1024").
-
+
Returns
-------
tuple[int, int]
Width and height as integers.
"""
- width, height = map(int, size.value.split('x'))
+ width, height = map(int, size.value.split("x"))
return width, height
-
+
def _build_generation_request_data(
- self,
- request: ImageGenerationRequest,
- width: int,
- height: int
+ self, request: ImageGenerationRequest, width: int, height: int
) -> dict[str, Any]:
"""
Build request data dictionary for image generation.
-
+
Parameters
----------
request : ImageGenerationRequest
@@ -143,7 +150,7 @@ def _build_generation_request_data(
Image width in pixels.
height : int
Image height in pixels.
-
+
Returns
-------
dict[str, Any]
@@ -156,20 +163,22 @@ def _build_generation_request_data(
"seed": request.seed,
"guidance": request.guidance_scale,
"width": width,
- "height": height
+ "height": height,
}
-
- def _build_edit_request_data(self, image_edit_request: ImageEditRequest, temp_file_paths: list[str]) -> dict[str, Any]:
+
+ def _build_edit_request_data(
+ self, image_edit_request: ImageEditRequest, temp_file_paths: list[str]
+ ) -> dict[str, Any]:
"""
Build request data dictionary for image editing.
-
+
Parameters
----------
image_edit_request : ImageEditRequest
The image editing request.
temp_file_paths : list[str]
List of temporary file paths.
-
+
Returns
-------
dict[str, Any]
@@ -182,18 +191,18 @@ def _build_edit_request_data(self, image_edit_request: ImageEditRequest, temp_fi
"steps": image_edit_request.steps,
"seed": image_edit_request.seed,
"guidance": image_edit_request.guidance_scale,
- "image_paths": temp_file_paths
+ "image_paths": temp_file_paths,
}
-
+
def _create_image_response(self, image_result: Image.Image) -> ImageGenerationResponse:
"""
Create image generation response from PIL Image.
-
+
Parameters
----------
image_result : Image.Image
The generated PIL Image.
-
+
Returns
-------
ImageGenerationResponse
@@ -201,14 +210,13 @@ def _create_image_response(self, image_result: Image.Image) -> ImageGenerationRe
"""
image_data_b64 = self._image_to_base64(image_result)
return ImageGenerationResponse(
- created=int(time.time()),
- data=[ImageData(b64_json=image_data_b64)]
+ created=int(time.time()), data=[ImageData(b64_json=image_data_b64)]
)
def _create_edit_response(self, image_result: Image.Image) -> ImageEditResponse:
"""
Create image editing response from PIL Image.
-
+
Parameters
----------
image_result : Image.Image
@@ -216,19 +224,18 @@ def _create_edit_response(self, image_result: Image.Image) -> ImageEditResponse:
"""
image_data_b64 = self._image_to_base64(image_result)
return ImageEditResponse(
- created=int(time.time()),
- data=[ImageData(b64_json=image_data_b64)]
+ created=int(time.time()), data=[ImageData(b64_json=image_data_b64)]
)
-
+
def _handle_queue_full_error(self, request_id: str) -> None:
"""
Handle queue capacity errors.
-
+
Parameters
----------
request_id : str
The request ID for logging.
-
+
Raises
------
HTTPException
@@ -236,92 +243,90 @@ def _handle_queue_full_error(self, request_id: str) -> None:
"""
logger.error(f"Queue at capacity for request {request_id}")
content = create_error_response(
- "Too many requests. Service is at capacity.",
- "rate_limit_exceeded",
- HTTPStatus.TOO_MANY_REQUESTS
+ "Too many requests. Service is at capacity.",
+ "rate_limit_exceeded",
+ HTTPStatus.TOO_MANY_REQUESTS,
)
raise HTTPException(status_code=429, detail=content)
-
+
def _handle_generation_error(self, request_id: str, error: Exception) -> None:
"""
Handle general generation errors.
-
+
Parameters
----------
request_id : str
The request ID for logging.
error : Exception
The exception that occurred.
-
+
Raises
------
HTTPException
500 error with error details.
"""
- logger.error(f"Error in image generation for request {request_id}: {str(error)}")
+ logger.error(f"Error in image generation for request {request_id}: {error!s}")
content = create_error_response(
- f"Failed to generate image: {str(error)}",
- "server_error",
- HTTPStatus.INTERNAL_SERVER_ERROR
+ f"Failed to generate image: {error!s}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR
)
raise HTTPException(status_code=500, detail=content)
def _handle_edit_error(self, request_id: str, error: Exception) -> None:
"""
Handle general editing errors.
-
+
Parameters
----------
request_id : str
The request ID for logging.
error : Exception
The exception that occurred.
-
+
Raises
------
HTTPException
500 error with error details.
"""
- logger.error(f"Error in image editing for request {request_id}: {str(error)}")
+ logger.error(f"Error in image editing for request {request_id}: {error!s}")
content = create_error_response(
- f"Failed to edit image: {str(error)}",
- "server_error",
- HTTPStatus.INTERNAL_SERVER_ERROR
+ f"Failed to edit image: {error!s}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR
)
raise HTTPException(status_code=500, detail=content)
-
+
def _validate_image_file(self, image: UploadFile, idx: int) -> None:
"""
Validate image file type and size.
-
+
Parameters
----------
image : UploadFile
The uploaded image file to validate.
idx : int
Index of the image (for error messages).
-
+
Raises
------
HTTPException
If validation fails.
"""
- if not image.content_type or image.content_type not in ["image/png", "image/jpeg", "image/jpg"]:
+ if not image.content_type or image.content_type not in [
+ "image/png",
+ "image/jpeg",
+ "image/jpg",
+ ]:
raise HTTPException(
- status_code=400,
- detail=f"Image {idx + 1} must be a PNG, JPEG, or JPG file"
+ status_code=400, detail=f"Image {idx + 1} must be a PNG, JPEG, or JPG file"
)
-
- if hasattr(image, 'size') and image.size and image.size > 10 * 1024 * 1024:
+
+ if hasattr(image, "size") and image.size and image.size > 10 * 1024 * 1024:
raise HTTPException(
- status_code=400,
- detail=f"Image {idx + 1} file size must be less than 10MB"
+ status_code=400, detail=f"Image {idx + 1} file size must be less than 10MB"
)
-
+
async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: str) -> str:
"""
Read, process, and save uploaded image to a temporary file.
-
+
Parameters
----------
image : UploadFile
@@ -330,12 +335,12 @@ async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: st
Index of the image (for error messages and file naming).
request_id : str
Request ID for file naming.
-
+
Returns
-------
str
Path to the temporary file.
-
+
Raises
------
HTTPException
@@ -345,39 +350,34 @@ async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: st
image_data = await image.read()
if not image_data:
raise HTTPException(
- status_code=400,
- detail=f"Empty image file received for image {idx + 1}"
+ status_code=400, detail=f"Empty image file received for image {idx + 1}"
)
-
+
# Load and process image
try:
input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
input_image = ImageOps.exif_transpose(input_image)
except Exception as img_error:
- logger.error(f"Failed to process image {idx + 1}: {str(img_error)}")
+ logger.error(f"Failed to process image {idx + 1}: {img_error!s}")
raise HTTPException(
- status_code=400,
- detail=f"Invalid or corrupted image file for image {idx + 1}"
+ status_code=400, detail=f"Invalid or corrupted image file for image {idx + 1}"
)
-
+
# Create and save to temporary file
try:
temp_file = tempfile.NamedTemporaryFile(
- delete=False,
- suffix=".png",
- prefix=f"edit_{request_id}_{idx + 1}_"
+ delete=False, suffix=".png", prefix=f"edit_{request_id}_{idx + 1}_"
)
temp_file_path = temp_file.name
input_image.save(temp_file_path, format="PNG")
temp_file.close()
return temp_file_path
except Exception as temp_error:
- logger.error(f"Failed to create temporary file for image {idx + 1}: {str(temp_error)}")
+ logger.error(f"Failed to create temporary file for image {idx + 1}: {temp_error!s}")
raise HTTPException(
- status_code=500,
- detail=f"Failed to process image {idx + 1} for editing"
+ status_code=500, detail=f"Failed to process image {idx + 1} for editing"
)
-
+
def _cleanup_temp_files(self, temp_file_paths: list[str]) -> None:
"""Clean up temporary files and force garbage collection."""
for temp_file_path in temp_file_paths:
@@ -386,65 +386,65 @@ def _cleanup_temp_files(self, temp_file_paths: list[str]) -> None:
os.unlink(temp_file_path)
logger.debug(f"Cleaned up temporary file: {temp_file_path}")
except OSError as cleanup_error:
- logger.warning(f"Failed to cleanup temporary file {temp_file_path}: {str(cleanup_error)}")
+ logger.warning(
+ f"Failed to cleanup temporary file {temp_file_path}: {cleanup_error!s}"
+ )
gc.collect()
async def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
"""
Generate an image based on the request parameters.
-
+
Parameters
----------
request : ImageGenerationRequest
Request object containing the generation parameters.
-
+
Returns
-------
ImageGenerationResponse
Response containing the generated image data.
-
+
Raises
------
HTTPException
For queue capacity issues or processing failures.
"""
request_id = f"image-{uuid.uuid4()}"
-
+
try:
# Parse image dimensions
width, height = 1024, 1024
if request.size:
width, height = self._parse_image_size(request.size)
-
+
# Build and submit request to the inference thread
request_data = self._build_generation_request_data(request, width, height)
- image_result = await self.inference_worker.submit(
- self._run_inference, request_data
- )
-
+ image_result = await self.inference_worker.submit(self._run_inference, request_data)
+
# Create and return response
return self._create_image_response(image_result)
-
+
except asyncio.QueueFull:
self._handle_queue_full_error(request_id)
-
+
except Exception as e:
self._handle_generation_error(request_id, e)
async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditResponse:
"""
Edit an image or multiple images based on the request parameters.
-
+
Parameters
----------
image_edit_request : ImageEditRequest
Request parameters for image editing.
-
+
Returns
-------
ImageEditResponse
Response containing the edited image data.
-
+
Raises
------
HTTPException
@@ -452,23 +452,22 @@ async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditRes
"""
# Normalize and validate inputs
images: list[UploadFile] = (
- image_edit_request.image
- if isinstance(image_edit_request.image, list)
+ image_edit_request.image
+ if isinstance(image_edit_request.image, list)
else [image_edit_request.image]
)
-
+
if not images:
raise HTTPException(
- status_code=400,
- detail="At least one image is required for image editing"
+ status_code=400, detail="At least one image is required for image editing"
)
-
+
for idx, image in enumerate(images):
self._validate_image_file(image, idx)
request_id = f"image-edit-{uuid.uuid4()}"
temp_file_paths: list[str] = []
-
+
try:
# Process all images to temporary files
for idx, image in enumerate(images):
@@ -478,31 +477,29 @@ async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditRes
# Submit request to the inference thread
request_data = self._build_edit_request_data(image_edit_request, temp_file_paths)
- image_result = await self.inference_worker.submit(
- self._run_inference, request_data
- )
-
+ image_result = await self.inference_worker.submit(self._run_inference, request_data)
+
return self._create_edit_response(image_result)
except asyncio.QueueFull:
self._handle_queue_full_error(request_id)
-
+
except HTTPException:
raise
-
+
except Exception as e:
self._handle_edit_error(request_id, e)
-
+
finally:
self._cleanup_temp_files(temp_file_paths)
-
+
def _image_to_base64(self, image: Image.Image) -> str:
"""
Convert PIL Image to base64 string.
-
+
Args:
image: PIL Image object.
-
+
Returns:
str: Base64 encoded image string.
"""
@@ -510,9 +507,9 @@ def _image_to_base64(self, image: Image.Image) -> str:
image.save(buffer, format="PNG")
buffer.seek(0)
image_data = buffer.getvalue()
- return base64.b64encode(image_data).decode('utf-8')
+ return base64.b64encode(image_data).decode("utf-8")
- def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image:
+ def _run_inference(self, request_data: dict[str, Any]) -> Image.Image:
"""Execute image generation/editing on the inference thread.
This method is submitted to ``InferenceWorker.submit`` and runs
@@ -539,7 +536,7 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image:
image_path = request_data.get("image_path") # For image editing
guidance = request_data.get("guidance")
image_paths = request_data.get("image_paths", [])
-
+
# Prepare model parameters
model_params = {
"num_inference_steps": steps,
@@ -548,20 +545,22 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image:
"guidance": guidance,
"image_paths": image_paths,
}
-
+
# Add negative prompt if provided
if negative_prompt:
model_params["negative_prompt"] = negative_prompt
-
+
# Add image path for image editing if provided
if image_path:
model_params["image_path"] = image_path
- logger.info(f"Processing image edit with prompt: {prompt[:50]}... and image: {image_path}")
+ logger.info(
+ f"Processing image edit with prompt: {prompt[:50]}... and image: {image_path}"
+ )
else:
logger.info(f"Generating image with prompt: {prompt[:50]}...")
-
+
# Log all model parameters
- logger.info(f"Model inference configurations:")
+ logger.info("Model inference configurations:")
logger.info(f" - Prompt: {prompt[:100]}...")
logger.info(f" - Negative prompt: {negative_prompt}")
logger.info(f" - Steps: {steps}")
@@ -571,22 +570,16 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image:
logger.info(f" - Guidance scale: {guidance}")
logger.info(f" - Image path: {image_path}")
logger.info(f" - Model params: {model_params}")
-
+
# Generate image
- image = self.model(
- prompt=prompt,
- seed=seed,
- **model_params
- )
+ image = self.model(prompt=prompt, seed=seed, **model_params)
return image
-
+
except Exception as e:
- logger.error(f"Error processing image generation request: {str(e)}")
+ logger.error(f"Error processing image generation request: {e!s}")
raise
- async def edit_image_from_paths(
- self, edit_data: Dict[str, Any]
- ) -> ImageEditResponse:
+ async def edit_image_from_paths(self, edit_data: dict[str, Any]) -> ImageEditResponse:
"""Edit an image from pre-saved file paths.
This method is used by ``HandlerProcessProxy`` for IPC: the
@@ -620,9 +613,7 @@ async def edit_image_from_paths(
"image_paths": temp_file_paths,
}
- image_result = await self.inference_worker.submit(
- self._run_inference, request_data
- )
+ image_result = await self.inference_worker.submit(self._run_inference, request_data)
return self._create_edit_response(image_result)
except asyncio.QueueFull:
@@ -634,7 +625,7 @@ async def edit_image_from_paths(
finally:
self._cleanup_temp_files(temp_file_paths)
- async def get_queue_stats(self) -> Dict[str, Any]:
+ async def get_queue_stats(self) -> dict[str, Any]:
"""Get statistics from the inference worker.
Returns
@@ -642,24 +633,24 @@ async def get_queue_stats(self) -> Dict[str, Any]:
dict[str, Any]
Dictionary with worker statistics.
"""
- if not hasattr(self, 'inference_worker'):
+ if not hasattr(self, "inference_worker"):
return {"error": "Inference worker not initialized"}
return self.inference_worker.get_stats()
async def cleanup(self) -> None:
"""Clean up resources and stop the inference worker."""
- if hasattr(self, '_cleaned') and self._cleaned:
+ if hasattr(self, "_cleaned") and self._cleaned:
return
self._cleaned = True
try:
logger.info("Cleaning up MLXFluxHandler resources")
- if hasattr(self, 'inference_worker'):
+ if hasattr(self, "inference_worker"):
self.inference_worker.stop()
logger.info("Inference worker stopped successfully")
except Exception as e:
- logger.error(f"Error during MLXFluxHandler cleanup: {str(e)}")
+ logger.error(f"Error during MLXFluxHandler cleanup: {e!s}")
# Force garbage collection
gc.collect()
@@ -671,11 +662,12 @@ def __del__(self):
Note: Async cleanup cannot be reliably performed in __del__.
Please use 'await cleanup()' explicitly.
"""
- if hasattr(self, '_cleaned') and self._cleaned:
+ if hasattr(self, "_cleaned") and self._cleaned:
return
# Set flag to prevent multiple cleanup attempts
self._cleaned = True
+
if __name__ == "__main__":
handler = MLXFluxHandler(model_path="qwen-image", config_name="qwen-image")
- print(handler.model.get_model_info("qwen-image"))
\ No newline at end of file
+ print(handler.model.get_model_info("qwen-image"))
diff --git a/app/handler/mlx_embeddings.py b/app/handler/mlx_embeddings.py
index dd4d9d14..c0565fe2 100644
--- a/app/handler/mlx_embeddings.py
+++ b/app/handler/mlx_embeddings.py
@@ -1,8 +1,8 @@
import gc
+from http import HTTPStatus
import time
+from typing import Any
import uuid
-from http import HTTPStatus
-from typing import Any, Dict, List
from fastapi import HTTPException
from loguru import logger
@@ -12,6 +12,7 @@
from ..schemas.openai import EmbeddingRequest
from ..utils.errors import create_error_response
+
class MLXEmbeddingsHandler:
"""
Handler class for making requests to the underlying MLX embeddings model service.
@@ -23,7 +24,7 @@ class MLXEmbeddingsHandler:
def __init__(self, model_path: str, max_concurrency: int = 1):
"""
Initialize the handler with the specified model path.
-
+
Args:
model_path (str): Path to the embeddings model to load.
max_concurrency (int): Maximum number of concurrent model inference tasks.
@@ -37,23 +38,25 @@ def __init__(self, model_path: str, max_concurrency: int = 1):
self.inference_worker = InferenceWorker()
logger.info(f"Initialized MLXEmbeddingsHandler with model path: {model_path}")
-
- async def get_models(self) -> List[Dict[str, Any]]:
+
+ async def get_models(self) -> list[dict[str, Any]]:
"""
Get list of available models with their metadata.
"""
try:
- return [{
- "id": self.model_path,
- "object": "model",
- "created": self.model_created,
- "owned_by": "local"
- }]
+ return [
+ {
+ "id": self.model_path,
+ "object": "model",
+ "created": self.model_created,
+ "owned_by": "local",
+ }
+ ]
except Exception as e:
- logger.error(f"Error getting models: {str(e)}")
+ logger.error(f"Error getting models: {e!s}")
return []
- async def initialize(self, config: Dict[str, Any]) -> None:
+ async def initialize(self, config: dict[str, Any]) -> None:
"""Initialize the handler and start the inference worker.
Parameters
@@ -72,10 +75,10 @@ async def initialize(self, config: Dict[str, Any]) -> None:
async def generate_embeddings_response(self, request: EmbeddingRequest):
"""
Generate embeddings for a given text input.
-
+
Args:
request: EmbeddingRequest object containing the text input.
-
+
Returns:
List[float]: Embeddings for the input text.
"""
@@ -88,17 +91,21 @@ async def generate_embeddings_response(self, request: EmbeddingRequest):
response = await self.inference_worker.submit(
self.model,
texts=request.input,
- max_length=getattr(request, 'max_length', 512),
+ max_length=getattr(request, "max_length", 512),
)
-
+
return response
except Exception as e:
- logger.error(f"Error in embeddings generation: {str(e)}")
- content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR)
+ logger.error(f"Error in embeddings generation: {e!s}")
+ content = create_error_response(
+ f"Failed to generate embeddings: {e!s}",
+ "server_error",
+ HTTPStatus.INTERNAL_SERVER_ERROR,
+ )
raise HTTPException(status_code=500, detail=content)
- async def get_queue_stats(self) -> Dict[str, Any]:
+ async def get_queue_stats(self) -> dict[str, Any]:
"""Get statistics from the inference worker.
Returns
@@ -118,15 +125,15 @@ async def cleanup(self) -> None:
"""
try:
logger.info("Cleaning up MLXEmbeddingsHandler resources")
- if hasattr(self, 'inference_worker'):
+ if hasattr(self, "inference_worker"):
self.inference_worker.stop()
- if hasattr(self, 'model'):
+ if hasattr(self, "model"):
self.model.cleanup()
# Force garbage collection
gc.collect()
logger.info("MLXEmbeddingsHandler cleanup completed successfully")
except Exception as e:
- logger.error(f"Error during MLXEmbeddingsHandler cleanup: {str(e)}")
+ logger.error(f"Error during MLXEmbeddingsHandler cleanup: {e!s}")
raise
def __del__(self):
@@ -135,7 +142,7 @@ def __del__(self):
Note: Async cleanup cannot be reliably performed in __del__.
Please use 'await cleanup()' explicitly.
"""
- if hasattr(self, '_cleaned') and self._cleaned:
+ if hasattr(self, "_cleaned") and self._cleaned:
return
# Set flag to prevent multiple cleanup attempts
- self._cleaned = True
\ No newline at end of file
+ self._cleaned = True
diff --git a/app/handler/mlx_lm.py b/app/handler/mlx_lm.py
index e5ff24d8..460049e2 100644
--- a/app/handler/mlx_lm.py
+++ b/app/handler/mlx_lm.py
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import AsyncGenerator
-from dataclasses import dataclass, field
+from dataclasses import dataclass
import gc
from http import HTTPStatus
import time
@@ -223,9 +223,7 @@ def refine_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]
return [{k: v for k, v in message.items() if v is not None} for message in messages]
- async def _build_inference_context(
- self, request: ChatCompletionRequest
- ) -> "_InferenceContext":
+ async def _build_inference_context(self, request: ChatCompletionRequest) -> "_InferenceContext":
"""Build the common inference context shared by stream and non-stream paths.
Handles: request parsing, message refinement, prompt encoding, KV cache
@@ -890,7 +888,9 @@ async def _prepare_text_request(
try:
# Extract only the fields consumed by MLX_LM.__call__ instead of
# serializing the entire Pydantic model with model_dump().
- chat_template_kwargs = request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {}
+ chat_template_kwargs = (
+ request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {}
+ )
if request.tools:
tools = [t.model_dump() for t in request.tools]
diff --git a/app/handler/mlx_vlm.py b/app/handler/mlx_vlm.py
index 7f2a4cb7..3c287d31 100644
--- a/app/handler/mlx_vlm.py
+++ b/app/handler/mlx_vlm.py
@@ -213,7 +213,9 @@ async def generate_multimodal_stream(self, request: ChatCompletionRequest):
"""
try:
- input_prompt, model_params, parsers_result = await self._build_inference_context(request)
+ input_prompt, model_params, parsers_result = await self._build_inference_context(
+ request
+ )
if self.debug:
log_debug_model_dispatch(
@@ -478,7 +480,9 @@ async def generate_multimodal_response(self, request: ChatCompletionRequest):
str: Complete response.
"""
try:
- input_prompt, model_params, parsers_result = await self._build_inference_context(request)
+ input_prompt, model_params, parsers_result = await self._build_inference_context(
+ request
+ )
if self.debug:
log_debug_model_dispatch(
@@ -783,7 +787,9 @@ async def _prepare_multimodal_request(
# Extract only the fields consumed downstream instead of serializing
# the entire Pydantic model with model_dump().
- chat_template_kwargs = request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {}
+ chat_template_kwargs = (
+ request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {}
+ )
if request.tools:
tools = [t.model_dump() for t in request.tools]
diff --git a/app/handler/mlx_whisper.py b/app/handler/mlx_whisper.py
index e62d1b7c..8501bf8c 100644
--- a/app/handler/mlx_whisper.py
+++ b/app/handler/mlx_whisper.py
@@ -1,11 +1,12 @@
+from collections.abc import AsyncGenerator
import gc
+from http import HTTPStatus
import json
import os
import tempfile
import time
+from typing import Any
import uuid
-from http import HTTPStatus
-from typing import Any, AsyncGenerator, Dict, List, Optional
from fastapi import HTTPException
from loguru import logger
@@ -23,6 +24,7 @@
)
from ..utils.errors import create_error_response
+
class MLXWhisperHandler:
"""
Handler class for making requests to the underlying MLX Whisper model service.
@@ -34,7 +36,7 @@ class MLXWhisperHandler:
def __init__(self, model_path: str, max_concurrency: int = 1):
"""
Initialize the handler with the specified model path.
-
+
Args:
model_path (str): Path to the model directory.
max_concurrency (int): Maximum number of concurrent model inference tasks.
@@ -48,23 +50,25 @@ def __init__(self, model_path: str, max_concurrency: int = 1):
self.inference_worker = InferenceWorker()
logger.info(f"Initialized MLXWhisperHandler with model path: {model_path}")
-
- async def get_models(self) -> List[Dict[str, Any]]:
+
+ async def get_models(self) -> list[dict[str, Any]]:
"""
Get list of available models with their metadata.
"""
try:
- return [{
- "id": self.model_path,
- "object": "model",
- "created": self.model_created,
- "owned_by": "local"
- }]
+ return [
+ {
+ "id": self.model_path,
+ "object": "model",
+ "created": self.model_created,
+ "owned_by": "local",
+ }
+ ]
except Exception as e:
- logger.error(f"Error getting models: {str(e)}")
+ logger.error(f"Error getting models: {e!s}")
return []
-
- async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> None:
+
+ async def initialize(self, queue_config: dict[str, Any] | None = None) -> None:
"""Initialize the handler and start the inference worker.
Parameters
@@ -85,13 +89,15 @@ async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> Non
self.inference_worker.start()
logger.info("Initialized MLXWhisperHandler and started inference worker")
- async def generate_transcription_response(self, request: TranscriptionRequest) -> TranscriptionResponse:
+ async def generate_transcription_response(
+ self, request: TranscriptionRequest
+ ) -> TranscriptionResponse:
"""
Generate a transcription response for the given request.
"""
request_id = f"transcription-{uuid.uuid4()}"
temp_file_path = None
-
+
try:
request_data = await self._prepare_transcription_request(request)
temp_file_path = request_data.get("audio_path")
@@ -106,15 +112,13 @@ async def generate_transcription_response(self, request: TranscriptionRequest) -
response_data = TranscriptionResponse(
text=response["text"],
usage=TranscriptionUsageAudio(
- type="duration",
- seconds=int(calculate_audio_duration(temp_file_path))
- )
+ type="duration", seconds=int(calculate_audio_duration(temp_file_path))
+ ),
)
if request.response_format == TranscriptionResponseFormat.JSON:
return response_data
- else:
- # dump to string for text response
- return json.dumps(response_data.model_dump())
+ # dump to string for text response
+ return json.dumps(response_data.model_dump())
finally:
# Clean up temporary file
if temp_file_path and os.path.exists(temp_file_path):
@@ -122,17 +126,15 @@ async def generate_transcription_response(self, request: TranscriptionRequest) -
os.unlink(temp_file_path)
logger.debug(f"Cleaned up temporary file: {temp_file_path}")
except Exception as e:
- logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}")
+ logger.warning(f"Failed to clean up temporary file {temp_file_path}: {e!s}")
async def generate_transcription_stream_from_data(
- self,
- request_data: Dict[str, Any],
- response_format: TranscriptionResponseFormat
+ self, request_data: dict[str, Any], response_format: TranscriptionResponseFormat
) -> AsyncGenerator[str, None]:
"""
Generate a transcription stream from prepared request data.
Yields SSE-formatted chunks with timing information.
-
+
Args:
request_data: Prepared request data with audio_path already saved
response_format: The response format (json or text)
@@ -140,7 +142,7 @@ async def generate_transcription_stream_from_data(
request_id = f"transcription-{uuid.uuid4()}"
created_time = int(time.time())
temp_file_path = request_data.get("audio_path")
-
+
try:
# Set stream mode and submit to inference thread
request_data["stream"] = True
@@ -164,17 +166,14 @@ async def generate_transcription_stream_from_data(
model=self.model_path,
choices=[
TranscriptionResponseStreamChoice(
- delta=Delta(
- content=chunk.get("text", "")
- ),
- finish_reason=None
+ delta=Delta(content=chunk.get("text", "")), finish_reason=None
)
- ]
+ ],
)
-
+
# Yield as SSE format
yield f"data: {stream_response.model_dump_json()}\n\n"
-
+
# Send final chunk with finish_reason
final_response = TranscriptionResponseStream(
id=request_id,
@@ -182,17 +181,14 @@ async def generate_transcription_stream_from_data(
created=created_time,
model=self.model_path,
choices=[
- TranscriptionResponseStreamChoice(
- delta=Delta(content=""),
- finish_reason="stop"
- )
- ]
+ TranscriptionResponseStreamChoice(delta=Delta(content=""), finish_reason="stop")
+ ],
)
yield f"data: {final_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
-
+
except Exception as e:
- logger.error(f"Error during transcription streaming: {str(e)}")
+ logger.error(f"Error during transcription streaming: {e!s}")
raise
finally:
# Clean up temporary file
@@ -201,15 +197,15 @@ async def generate_transcription_stream_from_data(
os.unlink(temp_file_path)
logger.debug(f"Cleaned up temporary file: {temp_file_path}")
except Exception as e:
- logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}")
+ logger.warning(f"Failed to clean up temporary file {temp_file_path}: {e!s}")
async def _save_uploaded_file(self, file) -> str:
"""
Save the uploaded file to a temporary location.
-
+
Args:
file: The uploaded file object.
-
+
Returns:
str: Path to the temporary file.
"""
@@ -218,39 +214,35 @@ async def _save_uploaded_file(self, file) -> str:
file_extension = os.path.splitext(file.filename)[1] if file.filename else ".wav"
print("file_extension", file_extension)
-
+
# Read file content first (this can only be done once with FastAPI uploads)
content = await file.read()
-
+
# Create temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file:
# Write the file contents
temp_file.write(content)
temp_path = temp_file.name
-
+
logger.debug(f"Saved uploaded file to temporary location: {temp_path}")
return temp_path
-
+
except Exception as e:
- logger.error(f"Error saving uploaded file: {str(e)}")
+ logger.error(f"Error saving uploaded file: {e!s}")
raise
- async def _prepare_transcription_request(
- self,
- request: TranscriptionRequest
- ) -> Dict[str, Any]:
+ async def _prepare_transcription_request(self, request: TranscriptionRequest) -> dict[str, Any]:
"""
Prepare a transcription request by parsing model parameters.
-
+
Args:
request: TranscriptionRequest object.
audio_path: Path to the audio file.
-
+
Returns:
Dict containing the request data ready for the model.
"""
try:
-
file = request.file
file_path = await self._save_uploaded_file(file)
@@ -258,41 +250,37 @@ async def _prepare_transcription_request(
"audio_path": file_path,
"verbose": False,
}
-
+
# Add optional parameters if provided
if request.temperature is not None:
request_data["temperature"] = request.temperature
-
+
if request.language is not None:
request_data["language"] = request.language
-
+
if request.prompt is not None:
request_data["initial_prompt"] = request.prompt
-
+
# Map additional parameters if they exist
decode_options = {}
if request.language is not None:
decode_options["language"] = request.language
-
+
# Add decode options to request data
request_data.update(decode_options)
-
+
logger.debug(f"Prepared transcription request: {request_data}")
-
+
return request_data
-
+
except Exception as e:
- logger.error(f"Failed to prepare transcription request: {str(e)}")
+ logger.error(f"Failed to prepare transcription request: {e!s}")
content = create_error_response(
- f"Failed to process request: {str(e)}",
- "bad_request",
- HTTPStatus.BAD_REQUEST
+ f"Failed to process request: {e!s}", "bad_request", HTTPStatus.BAD_REQUEST
)
raise HTTPException(status_code=400, detail=content)
- async def transcribe_from_data(
- self, request_data: Dict[str, Any]
- ) -> TranscriptionResponse:
+ async def transcribe_from_data(self, request_data: dict[str, Any]) -> TranscriptionResponse:
"""Run transcription from pre-processed request data.
This method is used by ``HandlerProcessProxy`` for IPC: the
@@ -330,13 +318,11 @@ async def transcribe_from_data(
try:
os.unlink(temp_file_path)
except Exception as e:
- logger.warning(
- f"Failed to clean up temp file {temp_file_path}: {e}"
- )
+ logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
async def transcribe_stream_from_data(
self,
- request_data: Dict[str, Any],
+ request_data: dict[str, Any],
) -> AsyncGenerator[str, None]:
"""Run streaming transcription from pre-processed request data.
@@ -408,11 +394,9 @@ async def transcribe_stream_from_data(
try:
os.unlink(temp_file_path)
except Exception as e:
- logger.warning(
- f"Failed to clean up temp file {temp_file_path}: {e}"
- )
+ logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}")
- async def get_queue_stats(self) -> Dict[str, Any]:
+ async def get_queue_stats(self) -> dict[str, Any]:
"""Get statistics from the inference worker.
Returns
@@ -432,12 +416,11 @@ async def cleanup(self) -> None:
"""
try:
logger.info("Cleaning up MLXWhisperHandler resources")
- if hasattr(self, 'inference_worker'):
+ if hasattr(self, "inference_worker"):
self.inference_worker.stop()
# Force garbage collection
gc.collect()
logger.info("MLXWhisperHandler cleanup completed successfully")
except Exception as e:
- logger.error(f"Error during MLXWhisperHandler cleanup: {str(e)}")
+ logger.error(f"Error during MLXWhisperHandler cleanup: {e!s}")
raise
-
diff --git a/app/message_converters/abstract_converter.py b/app/message_converters/abstract_converter.py
index 542d1acf..29994a0f 100644
--- a/app/message_converters/abstract_converter.py
+++ b/app/message_converters/abstract_converter.py
@@ -2,9 +2,10 @@
from typing import Any
+
class AbstractMessageConverter:
"""Abstract message converter class that should not be used directly.
-
+
Provided properties and methods should be used in derived classes to convert
messages to be compatible with specific model chat templates.
"""
@@ -13,4 +14,4 @@ def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any
"""Convert messages to be compatible with specific model chat templates"""
raise NotImplementedError(
"AbstractMessageConverter.convert_messages has not been implemented!"
- )
\ No newline at end of file
+ )
diff --git a/app/message_converters/glm4_moe.py b/app/message_converters/glm4_moe.py
index 4c1b5b27..b1052df6 100644
--- a/app/message_converters/glm4_moe.py
+++ b/app/message_converters/glm4_moe.py
@@ -5,17 +5,18 @@
from .abstract_converter import AbstractMessageConverter
+
class GLM4MoEMessageConverter(AbstractMessageConverter):
"""GLM4 MoE-specific message format converter"""
def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert message format to be compatible with GLM4 MoE chat templates.
-
+
Parameters
----------
messages : list[dict[str, Any]]
List of messages in OpenAI API format.
-
+
Returns
-------
list[dict[str, Any]]
@@ -32,12 +33,12 @@ def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any
def _convert_single_message(self, message: dict[str, Any]) -> dict[str, Any]:
"""Convert a single message.
-
+
Parameters
----------
message : dict[str, Any]
Single message to convert.
-
+
Returns
-------
dict[str, Any]
@@ -53,7 +54,7 @@ def _convert_single_message(self, message: dict[str, Any]) -> dict[str, Any]:
def _convert_tool_calls(self, tool_calls: list[dict[str, Any]]) -> None:
"""Convert arguments format in tool calls.
-
+
Parameters
----------
tool_calls : list[dict[str, Any]]
@@ -69,12 +70,12 @@ def _convert_tool_calls(self, tool_calls: list[dict[str, Any]]) -> None:
def _parse_arguments_string(self, arguments_str: str) -> Any:
"""Parse GLM4 MoE-specific argument string format.
-
+
Parameters
----------
arguments_str : str
Arguments in string format.
-
+
Returns
-------
Any
@@ -83,4 +84,4 @@ def _parse_arguments_string(self, arguments_str: str) -> Any:
try:
return json.loads(arguments_str)
except json.JSONDecodeError:
- return arguments_str
\ No newline at end of file
+ return arguments_str
diff --git a/app/models/mflux.py b/app/models/mflux.py
index 68c7b96d..30a3e74d 100644
--- a/app/models/mflux.py
+++ b/app/models/mflux.py
@@ -1,21 +1,21 @@
from __future__ import annotations
-from loguru import logger
-import inspect
-from PIL import Image
from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+import inspect
+from typing import Any
+from loguru import logger
from mflux.models.common.config import ModelConfig
-from mflux.models.z_image.variants import ZImageTurbo
from mflux.models.fibo.variants.txt2img.fibo import FIBO
-from mflux.models.flux.variants.txt2img.flux import Flux1
-from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext
-from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
-from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
+from mflux.models.flux.variants.txt2img.flux import Flux1
from mflux.models.flux2.variants.edit.flux2_klein_edit import Flux2KleinEdit
-
+from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein
+from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
+from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
+from mflux.models.z_image.variants import ZImageTurbo
+from PIL import Image
# -----------------------------------------------------------------------------
# Exceptions
@@ -24,22 +24,18 @@
class ImageModelError(Exception):
"""Base exception for image generation model errors."""
- pass
class ModelLoadError(ImageModelError):
"""Raised when model loading fails."""
- pass
class ModelGenerationError(ImageModelError):
"""Raised when image generation fails."""
- pass
class InvalidConfigurationError(ImageModelError):
"""Raised when configuration is invalid."""
- pass
# -----------------------------------------------------------------------------
@@ -48,8 +44,8 @@ class InvalidConfigurationError(ImageModelError):
def _lora_validate(
- lora_paths: Optional[list[str]] | None,
- lora_scales: Optional[list[float]] | None,
+ lora_paths: list[str] | None,
+ lora_scales: list[float] | None,
) -> None:
if (lora_paths is None) != (lora_scales is None):
raise InvalidConfigurationError(
@@ -69,9 +65,9 @@ def __init__(
self,
model_type: str,
model_config: ModelConfig,
- quantize: Optional[int] = None,
- lora_paths: Optional[list[str]] = None,
- lora_scales: Optional[list[float]] = None,
+ quantize: int | None = None,
+ lora_paths: list[str] | None = None,
+ lora_scales: list[float] | None = None,
) -> None:
if quantize is not None and quantize not in (4, 8, 16):
raise InvalidConfigurationError(
@@ -88,9 +84,9 @@ def __init__(
def from_name(
cls,
config_name: str,
- quantize: Optional[int] = None,
- lora_paths: Optional[list[str]] = None,
- lora_scales: Optional[list[float]] = None,
+ quantize: int | None = None,
+ lora_paths: list[str] | None = None,
+ lora_scales: list[float] | None = None,
) -> ModelConfiguration:
if config_name not in _CONFIG_REGISTRY:
available = ", ".join(_CONFIG_REGISTRY.keys())
@@ -147,15 +143,12 @@ def is_loaded(self) -> bool:
@abstractmethod
def _load_model(self) -> None:
"""Load the specific model implementation."""
- pass
def _generate_image(self, prompt: str, seed: int = 42, **kwargs: Any) -> Image.Image:
sig = inspect.signature(self._model.generate_image)
valid = set(sig.parameters.keys())
filtered = {k: v for k, v in kwargs.items() if k in valid}
- result = self._model.generate_image(
- prompt=prompt, seed=seed, **filtered
- )
+ result = self._model.generate_image(prompt=prompt, seed=seed, **filtered)
return result.image
def __call__(self, prompt: str, seed: int = 42, **kwargs: Any) -> Image.Image:
@@ -262,9 +255,9 @@ def __init__(
self,
model_path: str,
config_name: str,
- quantize: Optional[int] = None,
- lora_paths: Optional[list[str]] = None,
- lora_scales: Optional[list[float]] = None,
+ quantize: int | None = None,
+ lora_paths: list[str] | None = None,
+ lora_scales: list[float] | None = None,
) -> None:
self.model_path = model_path
self.config_name = config_name
@@ -332,4 +325,4 @@ def is_loaded(self) -> bool:
guidance=1.0,
image_paths=[image_path],
)
- image.save("examples/result.png")
\ No newline at end of file
+ image.save("examples/result.png")
diff --git a/app/models/mlx_embeddings.py b/app/models/mlx_embeddings.py
index 84454863..70620bd9 100644
--- a/app/models/mlx_embeddings.py
+++ b/app/models/mlx_embeddings.py
@@ -1,12 +1,13 @@
import gc
+
import mlx.core as mx
from mlx_embeddings.utils import load
-from typing import List, Optional
+
class MLX_Embeddings:
"""
A wrapper class for MLX Embeddings that handles memory management to prevent leaks.
-
+
This class provides a unified interface for generating embeddings from text inputs,
with proper cleanup of MLX arrays and memory management.
"""
@@ -14,26 +15,26 @@ class MLX_Embeddings:
def __init__(self, model_path: str):
"""
Initialize the MLX_Embeddings model.
-
+
Args:
model_name (str): Name of the model to load.
-
+
Raises:
ValueError: If model loading fails.
"""
try:
self.model, self.tokenizer = load(model_path)
except Exception as e:
- raise ValueError(f"Error loading model: {str(e)}")
-
- def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array:
+ raise ValueError(f"Error loading model: {e!s}")
+
+ def _get_embeddings(self, texts: list[str], max_length: int = 512) -> mx.array:
"""
Get embeddings for a list of texts with proper memory management.
-
+
Args:
texts: List of text inputs
max_length: Maximum sequence length for tokenization
-
+
Returns:
MLX array of embeddings
"""
@@ -42,25 +43,20 @@ def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array:
try:
# Tokenize inputs
tokenizer = getattr(self.tokenizer, "_tokenizer", self.tokenizer)
-
+
inputs = tokenizer(
- texts,
- return_tensors="np",
- padding=True,
- truncation=True,
- max_length=max_length
+ texts, return_tensors="np", padding=True, truncation=True, max_length=max_length
)
-
+
# Generate embeddings
outputs = self.model(
- mx.array(inputs["input_ids"]),
- attention_mask=mx.array(inputs["attention_mask"])
+ mx.array(inputs["input_ids"]), attention_mask=mx.array(inputs["attention_mask"])
).text_embeds
-
+
# Return a copy to ensure the result persists after cleanup
return mx.array(outputs)
-
- except Exception as e:
+
+ except Exception:
# Clean up on error
self._cleanup_arrays(inputs, outputs)
raise
@@ -75,25 +71,25 @@ def _cleanup_arrays(self, *arrays):
try:
if isinstance(array, dict):
for key, value in array.items():
- if hasattr(value, 'nbytes'):
+ if hasattr(value, "nbytes"):
del value
- elif hasattr(array, 'nbytes'):
+ elif hasattr(array, "nbytes"):
del array
except:
pass
-
+
# Clear MLX cache and force garbage collection
mx.clear_cache()
gc.collect()
- def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]]:
+ def __call__(self, texts: list[str], max_length: int = 512) -> list[list[float]]:
"""
Generate embeddings for a list of texts.
-
+
Args:
texts: List of text inputs
max_length: Maximum sequence length for tokenization
-
+
Returns:
List of embedding vectors as float lists
"""
@@ -106,7 +102,7 @@ def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]]
mx.clear_cache()
gc.collect()
return result
- except Exception as e:
+ except Exception:
# Clean up on error
mx.clear_cache()
gc.collect()
@@ -116,15 +112,15 @@ def cleanup(self):
"""Explicitly cleanup resources."""
try:
# Clear any cached model outputs
- if hasattr(self, 'model'):
+ if hasattr(self, "model"):
del self.model
- if hasattr(self, 'tokenizer'):
+ if hasattr(self, "tokenizer"):
del self.tokenizer
-
+
# Clear MLX cache and force garbage collection
mx.clear_cache()
gc.collect()
- except Exception as e:
+ except Exception:
# Log cleanup errors but don't raise
pass
@@ -132,6 +128,7 @@ def __del__(self):
"""Destructor to ensure cleanup on object deletion."""
self.cleanup()
+
if __name__ == "__main__":
model_path = "mlx-community/all-MiniLM-L6-v2-4bit"
model = MLX_Embeddings(model_path)
@@ -141,4 +138,4 @@ def __del__(self):
print(f"Generated embeddings shape: {len(embeddings)} x {len(embeddings[0])}")
finally:
# Explicit cleanup
- model.cleanup()
\ No newline at end of file
+ model.cleanup()
diff --git a/app/models/mlx_lm.py b/app/models/mlx_lm.py
index 82d3595d..063231f1 100644
--- a/app/models/mlx_lm.py
+++ b/app/models/mlx_lm.py
@@ -1,18 +1,18 @@
+from collections.abc import Generator
+from dataclasses import dataclass
import os
-import mlx.core as mx
+from typing import Any
+
from loguru import logger
+import mlx.core as mx
+from mlx_lm.generate import GenerationResponse, stream_generate
+from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import load
-from mlx_lm.generate import (
- stream_generate
-)
-from dataclasses import dataclass
-from mlx_lm.generate import GenerationResponse
from outlines.processors import JSONLogitsProcessor
-from mlx_lm.models.cache import make_prompt_cache
-from mlx_lm.sample_utils import make_sampler, make_logits_processors
-from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer
+
from ..utils.debug_logging import log_debug_chat_template
-from typing import List, Dict, Union, Generator, Any
+from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer
DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
@@ -24,6 +24,7 @@
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1000000"))
DEFAULT_REPETITION_CONTEXT_SIZE = int(os.getenv("DEFAULT_REPETITION_CONTEXT_SIZE", "20"))
+
@dataclass
class CompletionResponse:
"""
@@ -40,24 +41,36 @@ class CompletionResponse:
"""
text: str = None
- tokens: List[int] = None
+ tokens: list[int] = None
peak_memory: float = None
generation_tps: float = None
prompt_tps: float = None
prompt_tokens: int = None
generation_tokens: int = None
+
class MLX_LM:
"""
A wrapper class for MLX Language Model that handles both streaming and non-streaming inference.
-
+
This class provides a unified interface for generating text responses from text prompts,
supporting both streaming and non-streaming modes.
"""
- def __init__(self, model_path: str, draft_model_path: str = None, num_draft_tokens: int = 2, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False):
+ def __init__(
+ self,
+ model_path: str,
+ draft_model_path: str = None,
+ num_draft_tokens: int = 2,
+ context_length: int | None = None,
+ trust_remote_code: bool = False,
+ chat_template_file: str = None,
+ debug: bool = False,
+ ):
try:
- self.model, self.tokenizer = load(model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code})
+ self.model, self.tokenizer = load(
+ model_path, lazy=False, tokenizer_config={"trust_remote_code": trust_remote_code}
+ )
self.context_length = context_length
self.draft_model = None
self.draft_tokenizer = None
@@ -72,21 +85,29 @@ def __init__(self, model_path: str, draft_model_path: str = None, num_draft_toke
if chat_template_file:
if not os.path.exists(chat_template_file):
raise ValueError(f"Chat template file {chat_template_file} does not exist")
- with open(chat_template_file, "r") as f:
+ with open(chat_template_file) as f:
template_content = f.read()
self.tokenizer.chat_template = template_content
if self.debug:
- log_debug_chat_template(chat_template_file=chat_template_file, template_content=template_content)
+ log_debug_chat_template(
+ chat_template_file=chat_template_file, template_content=template_content
+ )
except Exception as e:
- raise ValueError(f"Error loading model: {str(e)}")
+ raise ValueError(f"Error loading model: {e!s}")
def _load_draft_model(self, draft_model_path: str, trust_remote_code: bool) -> None:
try:
- self.draft_model, self.draft_tokenizer = load(draft_model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code})
- self.context_length = None # speculative decoding does not support context length, should be set to None
+ self.draft_model, self.draft_tokenizer = load(
+ draft_model_path,
+ lazy=False,
+ tokenizer_config={"trust_remote_code": trust_remote_code},
+ )
+ self.context_length = (
+ None # speculative decoding does not support context length, should be set to None
+ )
self._validate_draft_tokenizer()
except Exception as e:
- raise ValueError(f"Error loading draft model: {str(e)}")
+ raise ValueError(f"Error loading draft model: {e!s}")
def _validate_draft_tokenizer(self) -> None:
if self.draft_tokenizer.vocab_size != self.tokenizer.vocab_size:
@@ -95,7 +116,7 @@ def _validate_draft_tokenizer(self) -> None:
"Speculative decoding may not work as expected."
)
- def create_prompt_cache(self) -> List[Any]:
+ def create_prompt_cache(self) -> list[Any]:
cache = make_prompt_cache(self.model, max_kv_size=self.context_length)
if self.draft_model:
cache += make_prompt_cache(self.draft_model, max_kv_size=self.context_length)
@@ -104,7 +125,9 @@ def create_prompt_cache(self) -> List[Any]:
def get_model_type(self) -> str:
return self.model_type
- def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwargs: Dict[str, Any]) -> str:
+ def create_input_prompt(
+ self, messages: list[dict[str, str]], chat_template_kwargs: dict[str, Any]
+ ) -> str:
use_partial = chat_template_kwargs.pop("_partial_mode", False)
return self.tokenizer.apply_chat_template(
@@ -115,16 +138,12 @@ def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwar
**chat_template_kwargs,
)
- def encode_prompt(self, input_prompt: str) -> List[int]:
+ def encode_prompt(self, input_prompt: str) -> list[int]:
return self.tokenizer.encode(input_prompt)
def __call__(
- self,
- input_ids: List[int],
- prompt_cache: List[Any] = None,
- stream: bool = False,
- **kwargs
- ) -> Union[CompletionResponse, Generator[GenerationResponse, None, None]]:
+ self, input_ids: list[int], prompt_cache: list[Any] = None, stream: bool = False, **kwargs
+ ) -> CompletionResponse | Generator[GenerationResponse, None, None]:
"""
Generate text response from the model.
@@ -137,6 +156,7 @@ def __call__(
- seed: Random seed (default: 0)
- max_tokens: Maximum number of tokens to generate (default: 256)
"""
+
# Set default parameters if not provided (use 'is not None' to preserve valid 0 values)
def _get(key, default):
v = kwargs.get(key)
@@ -172,16 +192,14 @@ def _get(key, default):
logits_processors = make_logits_processors(
logit_bias=logit_bias,
repetition_penalty=repetition_penalty,
- repetition_context_size=repetition_context_size
+ repetition_context_size=repetition_context_size,
)
json_schema = kwargs.get("schema")
if json_schema:
logits_processors.append(
JSONLogitsProcessor(
- schema=json_schema,
- tokenizer=self.outlines_tokenizer,
- tensor_library_name="mlx"
+ schema=json_schema, tokenizer=self.outlines_tokenizer, tensor_library_name="mlx"
)
)
@@ -189,12 +207,10 @@ def _get(key, default):
# None or negative values (e.g., -1) result in non-deterministic generation
if seed and seed >= 0:
mx.random.seed(seed)
-
+
prompt_progress_callback = kwargs.get("prompt_progress_callback")
-
- sampler = make_sampler(
- **sampler_kwargs
- )
+
+ sampler = make_sampler(**sampler_kwargs)
stream_response = stream_generate(
self.model,
@@ -206,7 +222,7 @@ def _get(key, default):
num_draft_tokens=self.num_draft_tokens,
prompt_cache=prompt_cache,
logits_processors=logits_processors,
- prompt_progress_callback=prompt_progress_callback
+ prompt_progress_callback=prompt_progress_callback,
)
if stream:
return stream_response
@@ -219,7 +235,7 @@ def _get(key, default):
tokens.append(chunk.token)
if chunk.finish_reason:
final_chunk = chunk
-
+
return CompletionResponse(
text=text,
tokens=tokens,
@@ -228,4 +244,4 @@ def _get(key, default):
prompt_tps=final_chunk.prompt_tps,
prompt_tokens=final_chunk.prompt_tokens,
generation_tokens=final_chunk.generation_tokens,
- )
\ No newline at end of file
+ )
diff --git a/app/models/mlx_vlm.py b/app/models/mlx_vlm.py
index 39dfbfdd..968dae29 100644
--- a/app/models/mlx_vlm.py
+++ b/app/models/mlx_vlm.py
@@ -1,14 +1,17 @@
+from collections.abc import Generator
+from dataclasses import dataclass
import os
-import torch
+from typing import Any
+
import mlx.core as mx
-from dataclasses import dataclass
-from typing import List, Dict, Union, Generator, Any
-from mlx_vlm.models.cache import make_prompt_cache
from mlx_vlm import load, stream_generate
+from mlx_vlm.models.cache import make_prompt_cache
from mlx_vlm.video_generate import process_vision_info
-from ..utils.prompt_cache import LRUPromptCache
from outlines.processors import JSONLogitsProcessor
+import torch
+
from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer
+from ..utils.prompt_cache import LRUPromptCache
# Default model parameters
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "100000"))
@@ -17,6 +20,7 @@
DEFAULT_SEED = int(os.getenv("DEFAULT_SEED", "0"))
DEFAULT_REPETITION_CONTEXT_SIZE = int(os.getenv("DEFAULT_REPETITION_CONTEXT_SIZE", "20"))
+
@dataclass
class CompletionResponse:
"""
@@ -33,97 +37,99 @@ class CompletionResponse:
"""
text: str = None
- tokens: List[int] = None
+ tokens: list[int] = None
peak_memory: float = None
generation_tps: float = None
prompt_tps: float = None
prompt_tokens: int = None
generation_tokens: int = None
+
class MLX_VLM:
"""
A wrapper class for MLX Multimodal Model that handles both streaming and non-streaming inference.
-
+
This class provides a unified interface for generating text responses from images and text prompts,
supporting both streaming and non-streaming modes.
"""
-
- def __init__(self, model_path: str, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None):
+
+ def __init__(
+ self,
+ model_path: str,
+ context_length: int | None = None,
+ trust_remote_code: bool = False,
+ chat_template_file: str = None,
+ ):
"""
Initialize the MLX_VLM model.
-
+
Args:
model_path (str): Path to the model directory containing model weights and configuration.
context_length (int | None): Maximum context length for the model. If None, uses model default.
trust_remote_code (bool): Enable trust_remote_code when loading models. Defaults to False.
+
Raises:
ValueError: If model loading fails.
"""
try:
- self.model, self.processor = load(model_path, lazy=False, trust_remote_code=trust_remote_code)
+ self.model, self.processor = load(
+ model_path, lazy=False, trust_remote_code=trust_remote_code
+ )
self.config = self.model.config
self.context_length = context_length
self.outlines_tokenizer = OutlinesTransformerTokenizer(self.processor.tokenizer)
if chat_template_file:
if not os.path.exists(chat_template_file):
raise ValueError(f"Chat template file {chat_template_file} does not exist")
- with open(chat_template_file, "r") as f:
+ with open(chat_template_file) as f:
self.processor.chat_template = f.read()
except Exception as e:
- raise ValueError(f"Error loading model: {str(e)}")
+ raise ValueError(f"Error loading model: {e!s}")
def _is_video_model(self):
- return hasattr(self.config, "video_token_id") or hasattr(
- self.config, "video_token_index"
- )
+ return hasattr(self.config, "video_token_id") or hasattr(self.config, "video_token_index")
def get_model_type(self):
return self.config.model_type
- def create_prompt_cache(self) -> List[Any]:
+ def create_prompt_cache(self) -> list[Any]:
return make_prompt_cache(self.model.language_model, max_kv_size=self.context_length)
- def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwargs: Dict[str, Any]) -> str:
+ def create_input_prompt(
+ self, messages: list[dict[str, str]], chat_template_kwargs: dict[str, Any]
+ ) -> str:
chat_template_kwargs.pop("_partial_mode", None)
return self.processor.apply_chat_template(
- messages,
- tokenize=False,
- add_generation_prompt=True,
- **chat_template_kwargs
+ messages, tokenize=False, add_generation_prompt=True, **chat_template_kwargs
)
- def create_inputs(self, text: str, images: List[str] = None, videos: List[str] = None) -> Dict[str, Any]:
+ def create_inputs(
+ self, text: str, images: list[str] = None, videos: list[str] = None
+ ) -> dict[str, Any]:
return self.processor(
- text=[text],
- images=images,
- videos=videos,
- padding=True,
- return_tensors="pt"
+ text=[text], images=images, videos=videos, padding=True, return_tensors="pt"
)
def __call__(
- self,
- prompt: str,
- prompt_cache: List[Any] = None,
- stream: bool = False,
- **kwargs
- ) -> Union[CompletionResponse, Generator[str, None, None]]:
+ self, prompt: str, prompt_cache: list[Any] = None, stream: bool = False, **kwargs
+ ) -> CompletionResponse | Generator[str, None, None]:
"""
Generate text response from images and messages.
-
+
Args:
prompt (str): The input prompt text.
prompt_cache (List[Any], optional): Prompt cache for faster inference.
stream (bool, optional): Whether to stream the response. Defaults to False.
**kwargs: Additional model parameters (temperature, max_tokens, etc.)
- schema (dict, optional): JSON schema for structured output generation.
-
+
Returns:
- Union[CompletionResponse, Generator[str, None, None]]:
+ Union[CompletionResponse, Generator[str, None, None]]:
- If stream=False: Complete response as CompletionResponse
- If stream=True: Generator yielding response chunks
"""
+
def _get(key, default):
v = kwargs.get(key)
return default if v is None else v
@@ -139,9 +145,7 @@ def _get(key, default):
if json_schema:
logits_processors.append(
JSONLogitsProcessor(
- schema=json_schema,
- tokenizer=self.outlines_tokenizer,
- tensor_library_name="mlx"
+ schema=json_schema, tokenizer=self.outlines_tokenizer, tensor_library_name="mlx"
)
)
@@ -153,10 +157,12 @@ def _get(key, default):
"max_tokens": max_tokens,
"temperature": _get("temperature", DEFAULT_TEMPERATURE),
"repetition_penalty": kwargs.get("repetition_penalty"),
- "repetition_context_size": _get("repetition_context_size", DEFAULT_REPETITION_CONTEXT_SIZE),
+ "repetition_context_size": _get(
+ "repetition_context_size", DEFAULT_REPETITION_CONTEXT_SIZE
+ ),
"top_p": _get("top_p", DEFAULT_TOP_P),
}
-
+
model_params.update(sampling_params)
response_generator = stream_generate(
@@ -164,8 +170,8 @@ def _get(key, default):
self.processor,
prompt=prompt,
prompt_cache=prompt_cache,
- logits_processors = logits_processors,
- **model_params
+ logits_processors=logits_processors,
+ **model_params,
)
if stream:
@@ -199,19 +205,21 @@ def _get(key, default):
model = MLX_VLM(model_path, context_length=2048)
- tools = [{
- "type": "function",
- "function": {
- "name": "get_weather",
- "description": "Get the weather for a given city",
- "parameters": {
- "type": "object",
- "properties": {
- "city": {"type": "string", "description": "The city to get the weather for"}
- }
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the weather for a given city",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "city": {"type": "string", "description": "The city to get the weather for"}
+ },
+ },
+ "required": ["city"],
},
- "required": ["city"]
- }}
+ }
]
chat_template_kwargs = {
"tools": tools,
@@ -221,15 +229,9 @@ def _get(key, default):
{
"role": "user",
"content": [
- {
- "type": "text",
- "text": "What is the weather in New York?"
- },
- {
- "type": "image",
- "image": image_path
- }
- ]
+ {"type": "text", "text": "What is the weather in New York?"},
+ {"type": "image", "image": image_path},
+ ],
}
]
prompt_cache = LRUPromptCache()
@@ -252,11 +254,11 @@ def _get(key, default):
"repetition_penalty": None,
"repetition_context_size": 20,
}
-
+
inputs.update(sampling_params)
inputs["schema"] = None
response = model(input_prompt, stream=False, **inputs)
-
- print("RESPONSE: ", response)
\ No newline at end of file
+
+ print("RESPONSE: ", response)
diff --git a/app/models/mlx_whisper.py b/app/models/mlx_whisper.py
index 61677337..f4afd89c 100644
--- a/app/models/mlx_whisper.py
+++ b/app/models/mlx_whisper.py
@@ -1,7 +1,8 @@
-import librosa
-import numpy as np
from functools import lru_cache
+
+import librosa
from mlx_whisper.transcribe import transcribe
+import numpy as np
SAMPLING_RATE = 16000
CHUNK_SIZE = 30
@@ -13,12 +14,14 @@ def load_audio(fname):
a, _ = librosa.load(fname, sr=SAMPLING_RATE, dtype=np.float32)
return a
+
@lru_cache(maxsize=32)
def calculate_audio_duration(audio_path: str) -> int:
"""Calculate the duration of the audio file in seconds."""
audio = load_audio(audio_path)
return len(audio) / SAMPLING_RATE
+
class MLX_Whisper:
def __init__(self, model_path: str):
self.model_path = model_path
@@ -28,32 +31,32 @@ def _transcribe_generator(self, audio_path: str, **kwargs):
# Load the audio file
audio = load_audio(audio_path)
duration = calculate_audio_duration(audio_path)
-
+
beg = 0.0
while beg < duration:
# Calculate chunk boundaries
chunk_end = min(beg + CHUNK_SIZE, duration)
-
+
# Extract audio chunk
beg_samples = int(beg * SAMPLING_RATE)
end_samples = int(chunk_end * SAMPLING_RATE)
audio_chunk = audio[beg_samples:end_samples]
-
+
# Transcribe chunk
result = transcribe(audio_chunk, path_or_hf_repo=self.model_path, **kwargs)
-
+
# Add timing information
result["chunk_start"] = beg
result["chunk_end"] = chunk_end
-
+
yield result
-
+
beg += CHUNK_SIZE
def __call__(self, audio_path: str, stream: bool = False, **kwargs):
"""
Transcribe audio file.
-
+
Args:
audio_path: Path to audio file
stream: If True, yields chunks. If False, transcribes entire file at once.
@@ -61,13 +64,12 @@ def __call__(self, audio_path: str, stream: bool = False, **kwargs):
"""
if stream:
return self._transcribe_generator(audio_path, **kwargs)
- else:
- return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs)
-
-
+ return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs)
+
+
if __name__ == "__main__":
model = MLX_Whisper("mlx-community/whisper-tiny")
# Non-streaming (fastest for most use cases)
result = model("examples/audios/podcast.wav", stream=True)
for chunk in result:
- print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}")
\ No newline at end of file
+ print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}")
diff --git a/app/parsers/abstract_parser.py b/app/parsers/abstract_parser.py
index aecdf927..a5331626 100644
--- a/app/parsers/abstract_parser.py
+++ b/app/parsers/abstract_parser.py
@@ -248,14 +248,13 @@ def extract_tool_calls_streaming(self, chunk: str) -> tuple[dict[str, list] | st
- extracted_content: Tool calls dict, passthrough chunk, or None
- is_complete: True if chunk should be sent, False if buffering
"""
+
def _merge_content_payload(
payload: dict[str, list] | None,
leading_content: str = "",
trailing_content: str = "",
) -> dict[str, list | str]:
- merged: dict[str, list | str] = (
- dict(payload) if isinstance(payload, dict) else {}
- )
+ merged: dict[str, list | str] = dict(payload) if isinstance(payload, dict) else {}
pieces: list[str] = []
if leading_content:
pieces.append(leading_content)
diff --git a/app/parsers/function_parameter.py b/app/parsers/function_parameter.py
index d56d28f3..3a1538bf 100644
--- a/app/parsers/function_parameter.py
+++ b/app/parsers/function_parameter.py
@@ -64,7 +64,9 @@ def __init__(
)
@staticmethod
- def _coerce_parameter_value(param_value: str) -> str | int | float | bool | list[Any] | dict[str, Any]:
+ def _coerce_parameter_value(
+ param_value: str,
+ ) -> str | int | float | bool | list[Any] | dict[str, Any]:
"""Parse tool argument values as JSON when possible."""
try:
loaded = json.loads(param_value)
@@ -121,7 +123,9 @@ def _extract_tool_calls_permissive(self, model_output: str) -> list[dict[str, st
function_content = block[function_content_start:function_close_idx]
arguments: dict[str, str | int | float | bool | list[Any] | dict[str, Any]] = {}
- for param_name_raw, param_value_raw in self.permissive_parameter_regex.findall(function_content):
+ for param_name_raw, param_value_raw in self.permissive_parameter_regex.findall(
+ function_content
+ ):
param_name = param_name_raw.strip()
param_value = param_value_raw.strip()
arguments[param_name] = self._coerce_parameter_value(param_value)
diff --git a/app/parsers/functiongemma.py b/app/parsers/functiongemma.py
index 97df74eb..d0bcf779 100644
--- a/app/parsers/functiongemma.py
+++ b/app/parsers/functiongemma.py
@@ -1,7 +1,7 @@
from __future__ import annotations
-import re
import json
+import re
from .abstract_parser import AbstractToolParser
@@ -32,12 +32,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) ->
def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
"""Extract tool calls from complete model output.
-
+
Parameters
----------
model_output : str
Complete model output containing tool calls.
-
+
Returns
-------
dict[str, list] | None
@@ -46,9 +46,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
"""
matches = self.tool_call_regex.findall(model_output)
if not matches:
- return {
- "content": model_output
- }
+ return {"content": model_output}
tool_calls = []
for match in matches:
function_name = match[0]
@@ -59,4 +57,3 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
return {
"tool_calls": tool_calls,
}
-
diff --git a/app/parsers/glm4_moe.py b/app/parsers/glm4_moe.py
index 1ef433be..1fa0b54e 100644
--- a/app/parsers/glm4_moe.py
+++ b/app/parsers/glm4_moe.py
@@ -1,7 +1,7 @@
from __future__ import annotations
-import re
import json
+import re
from .abstract_parser import AbstractToolParser
from .hermes import HermesReasoningParser
@@ -19,13 +19,15 @@ class GLM4MoEReasoningParser(HermesReasoningParser):
reasoning_content
"""
- def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None:
+ def __init__(
+ self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE
+ ) -> None:
"""Initialize the Hermes4 reasoning parser with appropriate regex patterns."""
super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
-
+
def respects_enable_thinking(self) -> bool:
"""Check if the reasoning parser respects the enable_thinking flag.
-
+
Returns
-------
bool
@@ -52,7 +54,7 @@ class GLM4MoEToolParser(AbstractToolParser):
def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None:
"""Initialize the GLM4 MoE tool parser with appropriate regex patterns."""
super().__init__(tool_open=tool_open, tool_close=tool_close)
-
+
self.func_call_regex = re.compile(r".*?", re.DOTALL)
self.func_detail_regex = re.compile(
r"(.*?)(.*?)?", re.DOTALL
@@ -64,12 +66,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) ->
def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
"""Extract tool calls from complete model output.
-
+
Parameters
----------
model_output : str
Complete model output containing tool calls in JSON format.
-
+
Returns
-------
dict[str, list] | None
@@ -78,9 +80,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
"""
matches = self.func_call_regex.findall(model_output)
if not matches:
- return {
- "content": model_output
- }
+ return {"content": model_output}
tool_calls = []
for match in matches:
tc_detail = self.func_detail_regex.search(match)
@@ -92,10 +92,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
arg_key = key.strip()
arg_value = value.strip()
arg_dct[arg_key] = arg_value
- tool_calls.append({
- "name": tc_name.strip(),
- "arguments": json.dumps(arg_dct, ensure_ascii=False)
- })
- return {
- "tool_calls": tool_calls
- }
+ tool_calls.append(
+ {"name": tc_name.strip(), "arguments": json.dumps(arg_dct, ensure_ascii=False)}
+ )
+ return {"tool_calls": tool_calls}
diff --git a/app/parsers/harmony.py b/app/parsers/harmony.py
index 838600c5..ad07b71e 100644
--- a/app/parsers/harmony.py
+++ b/app/parsers/harmony.py
@@ -1,26 +1,26 @@
from __future__ import annotations
from enum import Enum
-from openai_harmony import (
- load_harmony_encoding,
- HarmonyEncodingName,
- StreamableParser,
- Role
-)
+
+from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding
+
class ChannelType(Enum):
"""Enumeration of harmony channel types."""
ANALYSIS = "analysis"
- COMMENTARY = "commentary"
+ COMMENTARY = "commentary"
FINAL = "final"
+
class ToolParserState(Enum):
"""Enumeration of parser states."""
+
NORMAL = "normal"
FOUND_ARGUMENTS = "found_arguments"
END_STREAM = "end_stream"
+
class HarmonyParser:
"""Parser for Harmony encoding."""
@@ -39,7 +39,7 @@ def parse(self, text: str) -> dict[str, str] | None:
"""
if self.end_tool_chunk in text:
end_tool_index = text.find(self.end_tool_chunk)
- text = text[:end_tool_index + len(self.end_tool_chunk)]
+ text = text[: end_tool_index + len(self.end_tool_chunk)]
result = {
"content": None,
@@ -47,24 +47,28 @@ def parse(self, text: str) -> dict[str, str] | None:
"reasoning_content": None,
}
tokens = self.encoding.encode(text, allowed_special="all")
- parsed_messages = self.encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT)
+ parsed_messages = self.encoding.parse_messages_from_completion_tokens(
+ tokens, role=Role.ASSISTANT
+ )
for message in parsed_messages:
if message.channel == ChannelType.ANALYSIS.value:
result["reasoning_content"] = message.content[0].text
elif message.channel == ChannelType.COMMENTARY.value:
- result["tool_calls"].append({
- "name": message.recipient.replace("functions.", ""),
- "arguments": message.content[0].text
- })
+ result["tool_calls"].append(
+ {
+ "name": message.recipient.replace("functions.", ""),
+ "arguments": message.content[0].text,
+ }
+ )
elif message.channel == ChannelType.FINAL.value:
result["content"] = message.content[0].text
return result
-
+
def _build_result(
- self,
- reasoning_contents: list[str],
- tool_calls: list[dict[str, str]] | None,
- contents: list[str]
+ self,
+ reasoning_contents: list[str],
+ tool_calls: list[dict[str, str]] | None,
+ contents: list[str],
) -> dict[str, str | list | None]:
"""Build the result dictionary from accumulated content."""
return {
@@ -72,7 +76,7 @@ def _build_result(
"tool_calls": tool_calls,
"content": "".join(contents) or None,
}
-
+
def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | None, bool]:
"""Parse the chunk and return the parsed content."""
if self.state == ToolParserState.END_STREAM:
@@ -85,9 +89,9 @@ def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | No
# Check for end marker and truncate
if self.end_tool_chunk in chunk:
end_tool_index = chunk.find(self.end_tool_chunk)
- chunk = chunk[:end_tool_index + len(self.end_tool_chunk)]
+ chunk = chunk[: end_tool_index + len(self.end_tool_chunk)]
end_stream_state = True
-
+
# Process chunk tokens
chunk_tokens = self.encoding.encode(chunk, allowed_special="all")
for chunk_token in chunk_tokens:
@@ -115,15 +119,14 @@ def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | No
# Handle end of stream
if end_stream_state:
- tool_calls = [{
- "name": self.function_name_buffer,
- "arguments": "".join(self.arguments_buffer)
- }]
+ tool_calls = [
+ {"name": self.function_name_buffer, "arguments": "".join(self.arguments_buffer)}
+ ]
self.arguments_buffer = []
self.function_name_buffer = ""
self.state = ToolParserState.END_STREAM
return self._build_result(reasoning_contents, tool_calls, contents), True
-
+
return self._build_result(reasoning_contents, None, contents), False
def handle_parse_streaming_end(self) -> tuple[dict[str, str | list | None] | None, bool]:
diff --git a/app/parsers/hermes.py b/app/parsers/hermes.py
index a225a5ff..3452a230 100644
--- a/app/parsers/hermes.py
+++ b/app/parsers/hermes.py
@@ -1,7 +1,7 @@
from __future__ import annotations
-import re
import json
+import re
from .abstract_parser import (
AbstractReasoningParser,
diff --git a/app/parsers/kimi_k2.py b/app/parsers/kimi_k2.py
index 455933b3..bcedfe08 100644
--- a/app/parsers/kimi_k2.py
+++ b/app/parsers/kimi_k2.py
@@ -1,8 +1,9 @@
from __future__ import annotations
-import re
-import json
from enum import Enum
+import json
+import re
+
from .hermes import HermesToolParser
TOOL_CALL_SECTION_BEGIN = "<|tool_calls_section_begin|>"
@@ -18,17 +19,20 @@ class KimiK2ToolState(Enum):
NORMAL = "normal"
FOUND_TOOL_SECTION = "found_tool_section"
+
class KimiK2ToolParser(HermesToolParser):
"""Kimi K2 tool parser.
-
+
Handles tool calls in format:
<|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_calls_section_end|>
"""
- def __init__(self, tool_open: str = TOOL_CALL_SECTION_BEGIN, tool_close: str = TOOL_CALL_SECTION_END) -> None:
+ def __init__(
+ self, tool_open: str = TOOL_CALL_SECTION_BEGIN, tool_close: str = TOOL_CALL_SECTION_END
+ ) -> None:
"""Initialize Solar Open tool parser."""
super().__init__(tool_open=tool_open, tool_close=tool_close)
-
+
self.state = KimiK2ToolState.NORMAL
self.tool_call_section_regex = re.compile(
re.escape(self.tool_open) + r"(.*?)" + re.escape(self.tool_close),
@@ -83,4 +87,4 @@ def extract_tool_calls(self, tool_output: str) -> dict[str, list] | None:
tool_calls.append({"name": name, "arguments": json.dumps(arguments)})
if not tool_calls:
return None
- return {"tool_calls": tool_calls}
\ No newline at end of file
+ return {"tool_calls": tool_calls}
diff --git a/app/parsers/longcat_flash_lite.py b/app/parsers/longcat_flash_lite.py
index 138b2c74..857fd549 100644
--- a/app/parsers/longcat_flash_lite.py
+++ b/app/parsers/longcat_flash_lite.py
@@ -1,11 +1,13 @@
from __future__ import annotations
import re
+
from .glm4_moe import GLM4MoEToolParser
TOOL_OPEN = ""
TOOL_CLOSE = ""
+
class LongCatFlashLiteToolParser(GLM4MoEToolParser):
"""Tool parser for LongCat Flash Lite model's tool response format.
@@ -22,7 +24,7 @@ class LongCatFlashLiteToolParser(GLM4MoEToolParser):
def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None:
"""Initialize the LongCat Flash Lite tool parser with appropriate regex patterns."""
super().__init__(tool_open=tool_open, tool_close=tool_close)
-
+
self.func_call_regex = re.compile(r".*?", re.DOTALL)
self.func_detail_regex = re.compile(
r"(.*?)(.*?)?", re.DOTALL
@@ -30,4 +32,4 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) ->
self.func_arg_regex = re.compile(
r"(.*?)(?:\\n|\s)*(.*?)",
re.DOTALL,
- )
\ No newline at end of file
+ )
diff --git a/app/parsers/minimax_m2.py b/app/parsers/minimax_m2.py
index ca8a4cb0..5f0076b9 100644
--- a/app/parsers/minimax_m2.py
+++ b/app/parsers/minimax_m2.py
@@ -24,15 +24,13 @@ class MiniMaxM2ToolParser(GLM4MoEToolParser):
def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None:
"""Initialize the MiniMax M2 tool parser with appropriate regex patterns."""
super().__init__(tool_open=tool_open, tool_close=tool_close)
-
+
self.func_call_regex = re.compile(
r"(.*?)", re.DOTALL
)
-
+
# Regex patterns for parsing MiniMax tool calls
- self.func_detail_regex = re.compile(
- r'(.*)', re.DOTALL
- )
+ self.func_detail_regex = re.compile(r'(.*)', re.DOTALL)
self.func_arg_regex = re.compile(
r'(.*?)', re.DOTALL
- )
\ No newline at end of file
+ )
diff --git a/app/parsers/qwen3.py b/app/parsers/qwen3.py
index 7a7bc224..12a252bd 100644
--- a/app/parsers/qwen3.py
+++ b/app/parsers/qwen3.py
@@ -13,16 +13,18 @@ class Qwen3ReasoningParser(HermesReasoningParser):
reasoning_content
"""
- def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None:
+ def __init__(
+ self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE
+ ) -> None:
"""Initialize the Qwen3 reasoning parser with appropriate regex patterns."""
super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
def respects_enable_thinking(self) -> bool:
"""Check if the reasoning parser respects the enable_thinking flag.
-
+
Returns
-------
bool
True if the reasoning parser respects the enable_thinking flag, False otherwise.
"""
- return True
\ No newline at end of file
+ return True
diff --git a/app/parsers/qwen3_5.py b/app/parsers/qwen3_5.py
index e1a7263c..3b15564f 100644
--- a/app/parsers/qwen3_5.py
+++ b/app/parsers/qwen3_5.py
@@ -13,16 +13,18 @@ class Qwen35ReasoningParser(Qwen3MoEReasoningParser):
reasoning_content
"""
- def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None:
+ def __init__(
+ self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE
+ ) -> None:
"""Initialize the Qwen3.5 reasoning parser with appropriate regex patterns."""
super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
def respects_enable_thinking(self) -> bool:
"""Check if the reasoning parser respects the enable_thinking flag.
-
+
Returns
-------
bool
True if the reasoning parser respects the enable_thinking flag, False otherwise.
"""
- return True
\ No newline at end of file
+ return True
diff --git a/app/parsers/qwen3_moe.py b/app/parsers/qwen3_moe.py
index a1e6da8c..ee0548fa 100644
--- a/app/parsers/qwen3_moe.py
+++ b/app/parsers/qwen3_moe.py
@@ -15,16 +15,18 @@ class Qwen3MoEReasoningParser(HermesReasoningParser):
reasoning_content
"""
- def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None:
+ def __init__(
+ self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE
+ ) -> None:
"""Initialize the Qwen3 MoE reasoning parser with appropriate regex patterns."""
super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
def needs_redacted_reasoning_prefix(self) -> bool:
"""Check if the reasoning parser needs a redacted reasoning prefix.
-
+
Returns
-------
bool
True if the reasoning parser needs a redacted reasoning prefix, False otherwise.
"""
- return True
\ No newline at end of file
+ return True
diff --git a/app/parsers/solar_open.py b/app/parsers/solar_open.py
index 01d5fcf6..a4463ef1 100644
--- a/app/parsers/solar_open.py
+++ b/app/parsers/solar_open.py
@@ -1,12 +1,12 @@
from __future__ import annotations
-import json
from enum import Enum
+import json
from loguru import logger
-from .hermes import HermesReasoningParser
from .abstract_parser import AbstractToolParser
+from .hermes import HermesReasoningParser
REASONING_OPEN = "<|think|>"
REASONING_CLOSE = "<|end|>"
@@ -30,18 +30,20 @@ class SolarOpenToolState(Enum):
class SolarOpenReasoningParser(HermesReasoningParser):
"""Solar Open reasoning parser.
-
+
Handles reasoning content in format: <|think|>reasoning_content<|end|>
"""
- def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None:
+ def __init__(
+ self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE
+ ) -> None:
"""Initialize Solar Open reasoning parser."""
super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close)
class SolarOpenToolParser(AbstractToolParser):
"""Solar Open tool parser.
-
+
Handles tool calls in format:
<|tool_call:begin|><|tool_call:name|><|tool_call:args|><|tool_call:end|>
"""
@@ -56,12 +58,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) ->
def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
"""Extract tool calls from complete model output.
-
+
Parameters
----------
model_output : str
Complete model output containing tool calls.
-
+
Returns
-------
dict[str, list] | None
@@ -79,22 +81,12 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
while self.tool_open in remaining_output:
tool_call_open_idx = remaining_output.find(self.tool_open)
- tool_call_name_idx = remaining_output.find(
- self.tool_name_prefix, tool_call_open_idx
- )
- tool_call_args_idx = remaining_output.find(
- self.tool_args_prefix, tool_call_name_idx
- )
- tool_call_close_idx = remaining_output.find(
- self.tool_close, tool_call_args_idx
- )
+ tool_call_name_idx = remaining_output.find(self.tool_name_prefix, tool_call_open_idx)
+ tool_call_args_idx = remaining_output.find(self.tool_args_prefix, tool_call_name_idx)
+ tool_call_close_idx = remaining_output.find(self.tool_close, tool_call_args_idx)
# Validate all required tokens were found
- if (
- tool_call_name_idx == -1
- or tool_call_args_idx == -1
- or tool_call_close_idx == -1
- ):
+ if tool_call_name_idx == -1 or tool_call_args_idx == -1 or tool_call_close_idx == -1:
logger.warning(
f"Malformed tool call in output, missing required tokens: {remaining_output[:100]}"
)
@@ -113,31 +105,25 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None:
json.loads(tool_args) # Validate JSON format
tool_calls.append({"name": tool_name, "arguments": tool_args})
except json.JSONDecodeError as e:
- logger.warning(
- f"Invalid JSON in tool arguments for '{tool_name}': {e}"
- )
+ logger.warning(f"Invalid JSON in tool arguments for '{tool_name}': {e}")
# Skip this malformed tool call and continue
# Move past this tool call
- remaining_output = remaining_output[
- tool_call_close_idx + len(self.tool_close) :
- ]
+ remaining_output = remaining_output[tool_call_close_idx + len(self.tool_close) :]
return {
"tool_calls": tool_calls,
"content": remaining_output if remaining_output else None,
}
- def extract_tool_calls_streaming(
- self, chunk: str
- ) -> tuple[dict[str, list] | str | None, bool]:
+ def extract_tool_calls_streaming(self, chunk: str) -> tuple[dict[str, list] | str | None, bool]:
"""Extract tool calls from streaming chunks.
-
+
Parameters
----------
chunk : str
Chunk of model output to process.
-
+
Returns
-------
tuple[dict[str, list] | str | None, bool]
@@ -176,4 +162,4 @@ def extract_tool_calls_streaming(
return None, False
# Normal state - keep buffering
- return None, False
\ No newline at end of file
+ return None, False
diff --git a/app/schemas/openai.py b/app/schemas/openai.py
index 36a6086c..035d860d 100644
--- a/app/schemas/openai.py
+++ b/app/schemas/openai.py
@@ -2,18 +2,19 @@
from __future__ import annotations
-import uuid
-import time
from enum import Enum
+import time
from typing import Any, ClassVar, Literal, TypeAlias
+import uuid
-from loguru import logger
from fastapi import UploadFile
+from loguru import logger
from pydantic import BaseModel, ConfigDict, Field, model_validator
+
class OpenAIBaseModel(BaseModel):
"""Base model for OpenAI API schemas."""
-
+
# OpenAI API does allow extra fields
model_config = ConfigDict(extra="allow")
@@ -154,9 +155,10 @@ class PromptTokenUsageInfo(OpenAIBaseModel):
cached_tokens: int | None = None
+
class StreamOptions(OpenAIBaseModel):
"""Stream options for a request."""
-
+
include_usage: bool | None = True
continuous_usage_stats: bool | None = False
@@ -178,15 +180,16 @@ class FunctionCall(OpenAIBaseModel):
name: str
arguments: str
+
def random_uuid() -> str:
return str(uuid.uuid4().hex)
+
def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
if id_type == "kimi_k2":
return f"functions.{func_name}:{idx}"
- else:
- # by default return random
- return f"chatcmpl-tool-{random_uuid()}"
+ # by default return random
+ return f"chatcmpl-tool-{random_uuid()}"
class ChatCompletionMessageToolCall(OpenAIBaseModel):
@@ -224,6 +227,7 @@ class Message(OpenAIBaseModel):
"from this message instead of starting a new assistant turn (prefill / partial mode).",
)
+
class ChatTemplateKwargs(OpenAIBaseModel):
"""Represents the arguments for a chat template."""
@@ -232,44 +236,45 @@ class ChatTemplateKwargs(OpenAIBaseModel):
default="medium", description="The reasoning effort level."
)
+
class FunctionDefinition(OpenAIBaseModel):
name: str
description: str | None = None
parameters: dict[str, Any] | None = None
+
class ChatCompletionToolsParam(OpenAIBaseModel):
type: Literal["function"] = "function"
function: FunctionDefinition
+
class ChatCompletionNamedFunction(OpenAIBaseModel):
name: str
+
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
function: ChatCompletionNamedFunction
type: Literal["function"] = "function"
-class ChatCompletionRequest(OpenAIBaseModel):
+class ChatCompletionRequest(OpenAIBaseModel):
"""Request schema for OpenAI-compatible chat completion API."""
-
+
model: str = Field(Config.TEXT_MODEL, description="The model to use for completion.")
messages: list[Message] = Field(..., description="The list of messages in the conversation.")
tools: list[ChatCompletionToolsParam] | None = Field(
None, description="List of tools available for the request."
)
- tool_choice: (
- Literal["none"]
- | Literal["auto"]
- | Literal["required"]
- | ChatCompletionNamedToolChoiceParam
- | None
- ) = "none"
+ tool_choice: Literal["none", "auto", "required"] | ChatCompletionNamedToolChoiceParam | None = (
+ "none"
+ )
max_tokens: int | None = Field(
default=None,
- deprecated="max_tokens is deprecated in favor of "
- "the max_completion_tokens field",
+ deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
+ )
+ max_completion_tokens: int | None = Field(
+ None, description="Maximum number of tokens to generate."
)
- max_completion_tokens: int | None = Field(None, description="Maximum number of tokens to generate.")
temperature: float | None = Field(None, description="Sampling temperature.")
top_p: float | None = Field(None, description="Nucleus sampling probability.")
top_k: int | None = Field(None, description="Top-k sampling parameter.")
@@ -284,7 +289,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
n: int | None = Field(None, description="Number of completions to generate.")
response_format: dict[str, Any] | None = Field(None, description="Format for the response.")
seed: int | None = Field(
- None, description="The seed to use for sampling.",
+ None,
+ description="The seed to use for sampling.",
)
user: str | None = Field(None, description="User identifier.")
repetition_penalty: float | None = Field(
@@ -296,11 +302,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
xtc_probability: float | None = Field(
None, description="XTC (eXclude Top Choices) sampling probability (0.0-1.0)."
)
- xtc_threshold: float | None = Field(
- None, description="XTC sampling threshold (0.0-0.5)."
- )
+ xtc_threshold: float | None = Field(None, description="XTC sampling threshold (0.0-0.5).")
logit_bias: dict[str, float] | None = Field(
- None, description="Modify the likelihood of specified tokens appearing in the completion. Maps token IDs (as strings) to bias values from -100 to 100."
+ None,
+ description="Modify the likelihood of specified tokens appearing in the completion. Maps token IDs (as strings) to bias values from -100 to 100.",
)
stream: bool = Field(False, description="Whether to stream the response.")
stream_options: StreamOptions | None = None
@@ -452,6 +457,7 @@ class ImageSize(str, Enum):
MEDIUM = "512x512"
LARGE = "1024x1024"
+
class Priority(str, Enum):
"""Task priority levels."""
@@ -485,13 +491,9 @@ class TranscriptionResponseFormat(str, Enum):
class ImageGenerationRequest(OpenAIBaseModel):
"""Request schema for OpenAI-compatible image generation API."""
- prompt: str = Field(
- ...,
- description="A text description of the desired image(s)."
- )
+ prompt: str = Field(..., description="A text description of the desired image(s).")
negative_prompt: str | None = Field(
- None,
- description="A text description of the desired image(s)."
+ None, description="A text description of the desired image(s)."
)
model: str | None = Field(
default=Config.IMAGE_GENERATION_MODEL, description="The model to use for image generation"
@@ -546,7 +548,9 @@ class ImageGenerationError(OpenAIBaseModel):
class ImageEditRequest(OpenAIBaseModel):
"""Request data for OpenAI-compatible image edit API."""
- image: UploadFile | list[UploadFile] = Field(..., description="The image(s) to edit. Must be a file upload or a list of file uploads")
+ image: UploadFile | list[UploadFile] = Field(
+ ..., description="The image(s) to edit. Must be a file upload or a list of file uploads"
+ )
prompt: str = Field(..., description="The prompt for the image edit")
model: str | None = Field(
default=Config.IMAGE_EDIT_MODEL, description="The model to use for image edit"
@@ -650,14 +654,9 @@ class TranscriptionResponseStream(OpenAIBaseModel):
# --- Responses API Schemas ---
-from openai.types.responses import (
- ResponseStatus,
- ResponseInputItemParam,
- ResponseOutputItem,
-)
+from openai.types.responses import ResponseInputItemParam, ResponseOutputItem, ResponseStatus
+from openai.types.responses.response import IncompleteDetails, Tool, ToolChoice
from openai.types.shared import Reasoning
-from openai.types.responses.response import Tool, ToolChoice
-from openai.types.responses.response import IncompleteDetails
ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem
@@ -696,9 +695,7 @@ class ResponsesRequest(OpenAIBaseModel):
)
seed: int | None = Field(None, description="The seed for the response.")
text: ResponseTextConfig | None = None
- tools: list[Tool] | None = Field(
- None, description="List of tools to use for the response."
- )
+ tools: list[Tool] | None = Field(None, description="List of tools to use for the response.")
tool_choice: ToolChoice | None = Field(
default="auto", description="The tool choice to use for the response."
)
@@ -717,13 +714,14 @@ class OutputTokensDetails(OpenAIBaseModel):
output_tokens_per_turn: list[int] = Field(default_factory=list)
tool_output_tokens_per_turn: list[int] = Field(default_factory=list)
+
class ResponseUsage(OpenAIBaseModel):
input_tokens: int
input_tokens_details: InputTokensDetails
output_tokens: int
output_tokens_details: OutputTokensDetails
total_tokens: int
-
+
class ResponsesResponse(OpenAIBaseModel):
"""Represents a complete response from the Responses API."""
diff --git a/app/server.py b/app/server.py
index 4b40ac3c..b561b5ad 100644
--- a/app/server.py
+++ b/app/server.py
@@ -57,6 +57,7 @@ def ensure_image_handler_available(model_type: str) -> None:
raise RuntimeError(MFLUX_INSTALL_HINT)
+
_SAMPLING_DEFAULT_FIELDS: tuple[str, ...] = (
"default_max_tokens",
"default_temperature",
diff --git a/app/utils/debug_logging.py b/app/utils/debug_logging.py
index 106cfd4e..5797d597 100644
--- a/app/utils/debug_logging.py
+++ b/app/utils/debug_logging.py
@@ -2,12 +2,13 @@
import time
from typing import Any
+
from loguru import logger
def log_debug_request(request_dict: dict[str, Any]) -> None:
"""Log request details in a beautiful format for debug mode.
-
+
Parameters
----------
request_dict : dict[str, Any]
@@ -16,25 +17,27 @@ def log_debug_request(request_dict: dict[str, Any]) -> None:
logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
logger.info("🔍 DEBUG: Request Details")
logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
-
+
# Extract and format key information
if "messages" in request_dict:
logger.info(f"📨 Messages: {len(request_dict['messages'])} message(s)")
for i, msg in enumerate(request_dict["messages"], 1):
role = msg.get("role", "unknown")
content = msg.get("content", "")
- content_preview = str(content)[:100] + "..." if len(str(content)) > 100 else str(content)
+ content_preview = (
+ str(content)[:100] + "..." if len(str(content)) > 100 else str(content)
+ )
logger.info(f" {i}. [{role}] {content_preview}")
-
+
if request_dict.get("max_tokens"):
logger.info(f"🎯 Max Tokens: {request_dict['max_tokens']:,}")
-
+
if request_dict.get("temperature"):
logger.info(f"🌡️ Temperature: {request_dict['temperature']}")
-
+
if request_dict.get("top_p"):
logger.info(f"🎲 Top P: {request_dict['top_p']}")
-
+
logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
@@ -46,7 +49,7 @@ def log_debug_stats(
peak_memory: float,
) -> None:
"""Log generation statistics in a beautiful format for debug mode.
-
+
Parameters
----------
prompt_tokens : int
@@ -73,7 +76,7 @@ def log_debug_stats(
def log_debug_prompt(prompt: str) -> None:
"""Log input prompt in a beautiful format for debug mode.
-
+
Parameters
----------
prompt : str
@@ -86,7 +89,7 @@ def log_debug_prompt(prompt: str) -> None:
def log_debug_raw_text_response(raw_text: str) -> None:
"""Log raw text response in a beautiful format for debug mode.
-
+
Parameters
----------
raw_text : str
@@ -101,7 +104,7 @@ def log_debug_raw_text_response(raw_text: str) -> None:
def log_debug_cache_stats(total_input_tokens: int, remaining_tokens: int) -> None:
"""Log prompt cache statistics in a beautiful format for debug mode.
-
+
Parameters
----------
total_input_tokens : int
@@ -155,14 +158,15 @@ def log_debug_chat_template(
logger.info("✦ Using default chat template from model")
logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+
def make_prompt_progress_callback(start_time: float | None = None) -> callable:
"""Create a callback function for tracking prompt processing progress.
-
+
Parameters
----------
start_time : float | None
The start time for calculating speed. If None, uses current time.
-
+
Returns
-------
callable
@@ -170,13 +174,13 @@ def make_prompt_progress_callback(start_time: float | None = None) -> callable:
"""
if start_time is None:
start_time = time.time()
-
+
def callback(processed: int, total_tokens: int) -> None:
"""Log prompt processing progress with speed metrics."""
elapsed = time.time() - start_time
speed = processed / elapsed if elapsed > 0 else 0
logger.info(f"⚡ Processed {processed:6d}/{total_tokens} tokens ({speed:6.2f} tok/s)")
-
+
return callback
diff --git a/app/utils/prompt_cache.py b/app/utils/prompt_cache.py
index e17a813a..c2966dbe 100644
--- a/app/utils/prompt_cache.py
+++ b/app/utils/prompt_cache.py
@@ -2,8 +2,8 @@
from __future__ import annotations
-import copy
from collections import deque
+import copy
from dataclasses import dataclass
from typing import Any
@@ -139,6 +139,7 @@ def _search(self, tokens_ids: list[int]) -> SearchResult:
----------
tokens_ids : list[int]
Token sequence to search for.
+
Returns
-------
SearchResult
@@ -182,7 +183,6 @@ def _search(self, tokens_ids: list[int]) -> SearchResult:
return self.SearchResult(None, shorter, longer, common_prefix)
-
def fetch_nearest_cache(
self,
tokens_ids: list[int],
@@ -193,6 +193,7 @@ def fetch_nearest_cache(
----------
tokens_ids : list[int]
Token sequence to find a cache for.
+
Returns
-------
tuple[list[Any] | None, list[int]]
@@ -227,6 +228,7 @@ def _get(self, tokens_ids: list[int]) -> CacheEntry:
----------
tokens_ids : list[int]
Token sequence identifying the cache entry.
+
Returns
-------
CacheEntry
@@ -344,8 +346,10 @@ def log_cache_stats(self) -> None:
latest_checkpoint_tokens,
)
+
if __name__ == "__main__":
from app.models.mlx_lm import MLX_LM
+
model_path = "mlx-community/Qwen3-Coder-Next-8bit"
draft_model_path = "mlx-community/Qwen3-Coder-Next-4bit"
model = MLX_LM(model_path, draft_model_path)
@@ -374,7 +378,6 @@ def log_cache_stats(self) -> None:
first_token = False
cache_key.append(chunk.token)
-
prompt_cache.insert_cache(cache_key, cache)
start_time = time.time()
@@ -383,7 +386,7 @@ def log_cache_stats(self) -> None:
input_prompt_2 = model.create_input_prompt([{"role": "user", "content": prompt_2}], {})
input_ids_2 = model.encode_prompt(input_prompt_2)
cache, rest_input_ids_2 = prompt_cache.fetch_nearest_cache(input_ids_2)
-
+
if cache is None:
cache = model.create_prompt_cache()
# Use full input_ids for cache_key, not rest_input_ids
@@ -402,4 +405,4 @@ def log_cache_stats(self) -> None:
print("RAW TEXT", raw_text)
- prompt_cache.insert_cache(cache_key_2, cache)
\ No newline at end of file
+ prompt_cache.insert_cache(cache_key_2, cache)
diff --git a/configure_mlx.sh b/configure_mlx.sh
index f1cfe6e6..00e1bb80 100644
--- a/configure_mlx.sh
+++ b/configure_mlx.sh
@@ -40,4 +40,4 @@ else
sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB
sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \
sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB
-fi
\ No newline at end of file
+fi
diff --git a/examples/chat_templates/Longcat-Flash-Lite.jinja b/examples/chat_templates/Longcat-Flash-Lite.jinja
index 3166cb94..45b06416 100644
--- a/examples/chat_templates/Longcat-Flash-Lite.jinja
+++ b/examples/chat_templates/Longcat-Flash-Lite.jinja
@@ -65,4 +65,4 @@
{%- endfor %}
{%- if add_generation_prompt %}
{{- "" }}
-{%- endif %}
\ No newline at end of file
+{%- endif %}
diff --git a/examples/chat_templates/llama4.jinja b/examples/chat_templates/llama4.jinja
index 386d0320..bbed3d82 100644
--- a/examples/chat_templates/llama4.jinja
+++ b/examples/chat_templates/llama4.jinja
@@ -108,4 +108,4 @@
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|header_start|>assistant<|header_end|>\n\n' }}
-{%- endif %}
\ No newline at end of file
+{%- endif %}
diff --git a/examples/config.yaml b/examples/config.yaml
index 02d011ae..d30ab8bd 100644
--- a/examples/config.yaml
+++ b/examples/config.yaml
@@ -33,4 +33,4 @@ models:
model_type: image-generation
config_name: flux2-klein-4b
quantize: 4
- model_id: flux2-klein-4b
\ No newline at end of file
+ model_id: flux2-klein-4b
diff --git a/pyproject.toml b/pyproject.toml
index 86c8c167..5c3724e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,8 +34,8 @@ dependencies = [
"mlx-lm>=0.31.0,<0.32",
"mlx-vlm>=0.4.0,<0.5",
"mlx-whisper>=0.4.3,<0.5",
- "mlx>=0.31.0",
- "numpy>=2.2.0,<2.4",
+ "mlx>=0.31.0",
+ "numpy>=2.2.0,<2.5",
"openai>=2.21,<3",
"openai-harmony>=0.0.8,<0.1",
"outlines>=1.1.1,<1.2",
@@ -323,4 +323,4 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = ["untyped_package.*"]
-follow_untyped_imports = true
\ No newline at end of file
+follow_untyped_imports = true
diff --git a/tests/parsers/test_harmony.py b/tests/parsers/test_harmony.py
index f9f47ab5..c9fd8289 100644
--- a/tests/parsers/test_harmony.py
+++ b/tests/parsers/test_harmony.py
@@ -31,11 +31,11 @@ def test_harmony_reasoning_and_tool_parsing_streaming() -> None:
"json",
"<|message|>",
"{",
- "\"city\"",
+ '"city"',
":",
- "\"Tokyo\"",
+ '"Tokyo"',
"}",
- "<|call|>"
+ "<|call|>",
]
results = []
@@ -57,7 +57,7 @@ def test_harmony_reasoning_and_tool_parsing_streaming() -> None:
assert "tool_calls" in final_result
assert final_result["tool_calls"] is not None
assert len(final_result["tool_calls"]) == 1
-
+
tool_call = final_result["tool_calls"][0]
assert tool_call["name"] == "get_weather"
assert "Tokyo" in tool_call["arguments"]
@@ -79,7 +79,10 @@ def test_harmony_non_streaming_parse() -> None:
assert result is not None
assert "reasoning_content" in result
assert result["reasoning_content"] is not None
- assert "need" in result["reasoning_content"].lower() or "call" in result["reasoning_content"].lower()
+ assert (
+ "need" in result["reasoning_content"].lower()
+ or "call" in result["reasoning_content"].lower()
+ )
# Verify tool calls
assert "tool_calls" in result
@@ -87,7 +90,8 @@ def test_harmony_non_streaming_parse() -> None:
assert result["tool_calls"][0]["name"] == "get_weather"
assert "Tokyo" in result["tool_calls"][0]["arguments"]
+
if __name__ == "__main__":
test_harmony_reasoning_and_tool_parsing_streaming()
test_harmony_non_streaming_parse()
- print("All tests passed!")
\ No newline at end of file
+ print("All tests passed!")
diff --git a/tests/test_multi_model_per_model_defaults.py b/tests/test_multi_model_per_model_defaults.py
index 67a71f6f..dd70cb50 100644
--- a/tests/test_multi_model_per_model_defaults.py
+++ b/tests/test_multi_model_per_model_defaults.py
@@ -103,9 +103,7 @@ def _make_raw_request(registry: _FakeRegistry, handler: Any | None = None) -> An
"""Build a minimal request-like object for endpoint unit tests."""
return types.SimpleNamespace(
- app=types.SimpleNamespace(
- state=types.SimpleNamespace(registry=registry, handler=handler)
- ),
+ app=types.SimpleNamespace(state=types.SimpleNamespace(registry=registry, handler=handler)),
state=types.SimpleNamespace(request_id="req-test"),
)
@@ -335,7 +333,9 @@ async def _fake_process_text_request(
await endpoints_module.chat_completions(request, _make_raw_request(registry))
- assert captured_requests[0].temperature == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"]))
+ assert captured_requests[0].temperature == pytest.approx(
+ float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"])
+ )
assert captured_requests[0].top_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_P"]))
assert captured_requests[0].top_k == int(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_K"])
assert captured_requests[0].min_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_MIN_P"]))
@@ -343,7 +343,9 @@ async def _fake_process_text_request(
float(GLOBAL_ENV_DEFAULTS["DEFAULT_REPETITION_PENALTY"])
)
assert captured_requests[0].seed == int(GLOBAL_ENV_DEFAULTS["DEFAULT_SEED"])
- assert captured_requests[0].max_completion_tokens == int(GLOBAL_ENV_DEFAULTS["DEFAULT_MAX_TOKENS"])
+ assert captured_requests[0].max_completion_tokens == int(
+ GLOBAL_ENV_DEFAULTS["DEFAULT_MAX_TOKENS"]
+ )
assert captured_requests[0].xtc_probability == pytest.approx(
float(GLOBAL_ENV_DEFAULTS["DEFAULT_XTC_PROBABILITY"])
)
@@ -509,7 +511,9 @@ async def test_responses_single_model_implicit_handler_defaults_do_not_shadow_en
refined_request = endpoints_module.refine_responses_request(request, handler)
- assert refined_request.temperature == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"]))
+ assert refined_request.temperature == pytest.approx(
+ float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"])
+ )
assert refined_request.top_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_P"]))
assert refined_request.top_k == int(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_K"])
assert refined_request.min_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_MIN_P"]))
diff --git a/tests/test_prompt_cache_cancellation.py b/tests/test_prompt_cache_cancellation.py
index c214580a..1dabbb63 100644
--- a/tests/test_prompt_cache_cancellation.py
+++ b/tests/test_prompt_cache_cancellation.py
@@ -87,7 +87,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
chunk.token = 4
yield chunk
- mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen())
+ mock_inference_worker.submit_stream = Mock(
+ side_effect=lambda *args, **kwargs: mock_response_gen()
+ )
mock_prompt_cache = Mock()
mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), [])
@@ -115,7 +117,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
mock_parsers_result.is_unified = False
mock_parsers_result.reasoning_parser = None
mock_parsers_result.tool_parser = None
- with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result):
+ with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result):
fake_request = Mock()
gen = handler.generate_text_stream(fake_request)
try:
@@ -158,7 +160,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
chunk.generation_tokens = 20
yield chunk
- mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen())
+ mock_inference_worker.submit_stream = Mock(
+ side_effect=lambda *args, **kwargs: mock_response_gen()
+ )
mock_prompt_cache = Mock()
mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), [])
@@ -184,7 +188,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
mock_parsers_result.is_unified = False
mock_parsers_result.reasoning_parser = None
mock_parsers_result.tool_parser = None
- with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result):
+ with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result):
fake_request = Mock()
gen = handler.generate_text_stream(fake_request)
try:
@@ -246,7 +250,7 @@ async def blocked_response_gen() -> AsyncGenerator[None, None]:
mock_parsers_result.is_unified = False
mock_parsers_result.reasoning_parser = None
mock_parsers_result.tool_parser = None
- with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result):
+ with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result):
fake_request = Mock()
gen = handler.generate_text_stream(fake_request)
task = asyncio.create_task(gen.__anext__())
@@ -284,7 +288,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
chunk.token = token
yield chunk
- mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen())
+ mock_inference_worker.submit_stream = Mock(
+ side_effect=lambda *args, **kwargs: mock_response_gen()
+ )
mock_prompt_cache = Mock()
mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), [])
@@ -310,7 +316,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]:
mock_parsers_result.is_unified = False
mock_parsers_result.reasoning_parser = None
mock_parsers_result.tool_parser = None
- with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result):
+ with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result):
fake_request = Mock()
gen = handler.generate_text_stream(fake_request)
try:
diff --git a/tests/test_responses_request_conversion.py b/tests/test_responses_request_conversion.py
index 945819d9..a6fecf0a 100644
--- a/tests/test_responses_request_conversion.py
+++ b/tests/test_responses_request_conversion.py
@@ -133,7 +133,7 @@ def test_convert_responses_stream_history_skips_reasoning_but_keeps_tool_turns()
"status": "completed",
"call_id": "call_123",
"name": "du",
- "arguments": "{\"path\":\"/tmp\"}",
+ "arguments": '{"path":"/tmp"}',
},
{
"type": "function_call_output",
@@ -180,10 +180,12 @@ def test_convert_responses_stream_history_skips_reasoning_but_keeps_tool_turns()
"user",
]
- tool_call_messages = [m for m in chat_request.messages if m.role == "assistant" and m.tool_calls]
+ tool_call_messages = [
+ m for m in chat_request.messages if m.role == "assistant" and m.tool_calls
+ ]
assert len(tool_call_messages) == 1
assert tool_call_messages[0].tool_calls[0].function.name == "du"
- assert tool_call_messages[0].tool_calls[0].function.arguments == "{\"path\":\"/tmp\"}"
+ assert tool_call_messages[0].tool_calls[0].function.arguments == '{"path":"/tmp"}'
assert tool_call_messages[0].tool_calls[0].id == "call_123"
tool_output_messages = [m for m in chat_request.messages if m.role == "tool"]