diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a6a7bc65..cd40bb4c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -29,4 +29,4 @@ jobs: env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} - run: twine upload dist/* \ No newline at end of file + run: twine upload dist/* diff --git a/Makefile b/Makefile index 0398c325..519efd24 100644 --- a/Makefile +++ b/Makefile @@ -7,4 +7,4 @@ run: --queue-size 100 install: - pip install -e . \ No newline at end of file + pip install -e . diff --git a/app/__init__.py b/app/__init__.py index 0b90d88e..6e5d4622 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,7 +1,8 @@ import os + from .version import __version__ # Suppress transformers warnings -os.environ['TRANSFORMERS_VERBOSITY'] = 'error' +os.environ["TRANSFORMERS_VERBOSITY"] = "error" -__all__ = ["__version__"] \ No newline at end of file +__all__ = ["__version__"] diff --git a/app/api/endpoints.py b/app/api/endpoints.py index 91d0d4c7..fc247a84 100644 --- a/app/api/endpoints.py +++ b/app/api/endpoints.py @@ -419,6 +419,7 @@ async def chat_completions( content=create_error_response(str(e)), status_code=HTTPStatus.INTERNAL_SERVER_ERROR ) + @router.post("/v1/embeddings", response_model=None) async def embeddings( request: EmbeddingRequest, raw_request: Request @@ -855,7 +856,6 @@ async def process_multimodal_request( return JSONResponse(content=final_response.model_dump(exclude_none=True)) - async def process_text_request( handler: MLXLMHandler | MLXVLMHandler, request: ChatCompletionRequest, @@ -990,10 +990,12 @@ def format_final_response( request_id=request_id, ) + # ============================================================================= # Responses API Handlers # ============================================================================= + def _normalize_responses_item(item: Any) -> dict[str, Any]: """Normalize TypedDict/BaseModel response item to a plain dictionary.""" if isinstance(item, dict): @@ -1014,7 +1016,9 @@ def _serialize_responses_tool_output(output: Any) -> str: text_parts: list[str] = [] for output_item in output: normalized = _normalize_responses_item(output_item) - if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get("text"): + if normalized.get("type") in {"input_text", "output_text", "text"} and normalized.get( + "text" + ): text_parts.append(str(normalized["text"])) if text_parts: return "\n".join(text_parts) @@ -1104,9 +1108,7 @@ def _convert_responses_tool_choice(tool_choice: Any) -> Any: return "auto" -def convert_responses_request_to_chat_request( - request: ResponsesRequest -) -> ChatCompletionRequest: +def convert_responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest: """Convert a Responses request into a ChatCompletionRequest with full turn history.""" chat_messages: list[Message] = [] pending_tool_calls: list[ChatCompletionMessageToolCall] = [] @@ -1188,7 +1190,9 @@ def flush_pending_tool_calls() -> None: if item_type in {"input_text", "text"}: text = item.get("text") if text: - pending_user_parts.append(ChatCompletionContentPartText(type="text", text=str(text))) + pending_user_parts.append( + ChatCompletionContentPartText(type="text", text=str(text)) + ) elif item_type in {"input_image", "image_url"} and item.get("image_url"): pending_user_parts.append( ChatCompletionContentPartImage( @@ -1253,10 +1257,9 @@ def flush_pending_tool_calls() -> None: return ChatCompletionRequest(**chat_request_payload) + def format_final_responses_response( - response: str | dict[str, Any], - request: ResponsesRequest, - usage: UsageInfo | None = None + response: str | dict[str, Any], request: ResponsesRequest, usage: UsageInfo | None = None ) -> ResponsesResponse: """Format the final non-streaming response.""" response_payload: dict[str, Any] @@ -1391,6 +1394,7 @@ def refine_responses_request( request.model = Config.TEXT_MODEL return request + async def handle_responses_stream_response( generator: AsyncGenerator[Any, None], request: ResponsesRequest, @@ -1438,9 +1442,7 @@ def _create_base_response( text_val = None if request.text: text_val = ( - request.text.model_dump() - if hasattr(request.text, "model_dump") - else request.text + request.text.model_dump() if hasattr(request.text, "model_dump") else request.text ) return { "id": resp_id, @@ -1666,7 +1668,12 @@ def _open_message_item() -> str: { "id": msg_item_id, "content": [ - {"annotations": [], "text": full_text, "type": "output_text", "logprobs": None} + { + "annotations": [], + "text": full_text, + "type": "output_text", + "logprobs": None, + } ], "role": "assistant", "status": "completed", @@ -1707,9 +1714,9 @@ def _open_message_item() -> str: f"data: {json.dumps({'response': final_response_obj, 'sequence_number': _next_seq(), 'type': 'response.completed'})}\n\n" ) + async def process_text_responses_request( - handler: MLXLMHandler, - request: ResponsesRequest + handler: MLXLMHandler, request: ResponsesRequest ) -> ResponsesResponse | StreamingResponse | JSONResponse: """Handle text-only Responses API requests.""" refined_request = refine_responses_request(request, handler) @@ -1764,13 +1771,10 @@ async def process_multimodal_responses_request( result = await handler.generate_multimodal_response(chat_request) response_data = result.get("response") usage = result.get("usage") - final_response = format_final_responses_response( - response_data, - refined_request, - usage - ) + final_response = format_final_responses_response(response_data, refined_request, usage) return JSONResponse(content=final_response.model_dump(exclude_none=True)) + @router.post("/v1/responses", response_model=None) async def responses_endpoint( request: ResponsesRequest, raw_request: Request diff --git a/app/core/audio_processor.py b/app/core/audio_processor.py index 814fb292..f871987d 100644 --- a/app/core/audio_processor.py +++ b/app/core/audio_processor.py @@ -1,17 +1,17 @@ -import os -import gc import asyncio -from typing import List +import gc +import os + from .base_processor import BaseProcessor class AudioProcessor(BaseProcessor): """Audio processor for handling audio files with caching and validation.""" - + def __init__(self, max_workers: int = 4, cache_size: int = 1000): super().__init__(max_workers, cache_size) # Supported audio formats - self._supported_formats = {'.mp3', '.wav'} + self._supported_formats = {".mp3", ".wav"} def _get_media_format(self, media_url: str, data: bytes = None) -> str: """Determine audio format from URL or data.""" @@ -20,22 +20,22 @@ def _get_media_format(self, media_url: str, data: bytes = None) -> str: mime_type = media_url.split(";")[0].split(":")[1] if "mp3" in mime_type or "mpeg" in mime_type: return "mp3" - elif "wav" in mime_type: + if "wav" in mime_type: return "wav" - elif "m4a" in mime_type or "mp4" in mime_type: + if "m4a" in mime_type or "mp4" in mime_type: return "m4a" - elif "ogg" in mime_type: + if "ogg" in mime_type: return "ogg" - elif "flac" in mime_type: + if "flac" in mime_type: return "flac" - elif "aac" in mime_type: + if "aac" in mime_type: return "aac" else: # Extract format from file extension ext = os.path.splitext(media_url.lower())[1] if ext in self._supported_formats: return ext[1:] # Remove the dot - + # Default to mp3 if format cannot be determined return "mp3" @@ -43,27 +43,27 @@ def _validate_media_data(self, data: bytes) -> bool: """Basic validation of audio data.""" if len(data) < 100: # Too small to be a valid audio file return False - + # Check for common audio file signatures audio_signatures = [ - b'ID3', # MP3 with ID3 tag - b'\xff\xfb', # MP3 frame header - b'\xff\xf3', # MP3 frame header - b'\xff\xf2', # MP3 frame header - b'RIFF', # WAV/AVI - b'OggS', # OGG - b'fLaC', # FLAC - b'\x00\x00\x00\x20ftypM4A', # M4A + b"ID3", # MP3 with ID3 tag + b"\xff\xfb", # MP3 frame header + b"\xff\xf3", # MP3 frame header + b"\xff\xf2", # MP3 frame header + b"RIFF", # WAV/AVI + b"OggS", # OGG + b"fLaC", # FLAC + b"\x00\x00\x00\x20ftypM4A", # M4A ] - + for sig in audio_signatures: if data.startswith(sig): return True - + # Check for WAV format (RIFF header might be at different position) - if b'WAVE' in data[:50]: + if b"WAVE" in data[:50]: return True - + return True # Allow unknown formats to pass through def _get_timeout(self) -> int: @@ -76,7 +76,7 @@ def _get_max_file_size(self) -> int: def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: """Process audio data and save to cached path.""" - with open(cached_path, 'wb') as f: + with open(cached_path, "wb") as f: f.write(data) self._cleanup_old_files() return cached_path @@ -89,10 +89,10 @@ async def process_audio_url(self, audio_url: str) -> str: """Process a single audio URL and return path to cached file.""" return await self._process_single_media(audio_url) - async def process_audio_urls(self, audio_urls: List[str]) -> List[str]: + async def process_audio_urls(self, audio_urls: list[str]) -> list[str]: """Process multiple audio URLs and return paths to cached files.""" tasks = [self.process_audio_url(url) for url in audio_urls] results = await asyncio.gather(*tasks, return_exceptions=True) # Force garbage collection after batch processing gc.collect() - return results \ No newline at end of file + return results diff --git a/app/core/base_processor.py b/app/core/base_processor.py index 52833fcf..74b75c8c 100644 --- a/app/core/base_processor.py +++ b/app/core/base_processor.py @@ -1,30 +1,31 @@ +from abc import ABC, abstractmethod import base64 +from concurrent.futures import ThreadPoolExecutor +import gc import hashlib import os import tempfile -import aiohttp import time -import gc +from typing import Any + +import aiohttp from loguru import logger -from typing import Dict, Optional, Any -from concurrent.futures import ThreadPoolExecutor -from abc import ABC, abstractmethod class BaseProcessor(ABC): """Base class for media processors with common caching and session management.""" - + def __init__(self, max_workers: int = 4, cache_size: int = 1000): # Use tempfile for macOS-efficient temporary file handling self.temp_dir = tempfile.TemporaryDirectory() - self._session: Optional[aiohttp.ClientSession] = None + self._session: aiohttp.ClientSession | None = None self.executor = ThreadPoolExecutor(max_workers=max_workers) self._cache_size = cache_size self._last_cleanup = time.time() self._cleanup_interval = 3600 # 1 hour # Replace lru_cache with manual cache for better control - self._hash_cache: Dict[str, str] = {} - self._cache_access_times: Dict[str, float] = {} + self._hash_cache: dict[str, str] = {} + self._cache_access_times: dict[str, float] = {} def _get_media_hash(self, media_url: str) -> str: """Get hash for media URL with manual caching that can be cleared.""" @@ -32,20 +33,20 @@ def _get_media_hash(self, media_url: str) -> str: if media_url in self._hash_cache: self._cache_access_times[media_url] = time.time() return self._hash_cache[media_url] - + # Generate hash if media_url.startswith("data:"): _, encoded = media_url.split(",", 1) data = base64.b64decode(encoded) else: - data = media_url.encode('utf-8') - + data = media_url.encode("utf-8") + hash_value = hashlib.md5(data).hexdigest() - + # Add to cache with size management if len(self._hash_cache) >= self._cache_size: self._evict_oldest_cache_entries() - + self._hash_cache[media_url] = hash_value self._cache_access_times[media_url] = time.time() return hash_value @@ -54,53 +55,47 @@ def _evict_oldest_cache_entries(self): """Remove oldest 20% of cache entries to make room.""" if not self._cache_access_times: return - + # Sort by access time and remove oldest 20% sorted_items = sorted(self._cache_access_times.items(), key=lambda x: x[1]) to_remove = len(sorted_items) // 5 # Remove 20% - + for url, _ in sorted_items[:to_remove]: self._hash_cache.pop(url, None) self._cache_access_times.pop(url, None) - + # Force garbage collection after cache eviction gc.collect() @abstractmethod def _get_media_format(self, media_url: str, data: bytes = None) -> str: """Determine media format from URL or data. Must be implemented by subclasses.""" - pass @abstractmethod def _validate_media_data(self, data: bytes) -> bool: """Validate media data. Must be implemented by subclasses.""" - pass @abstractmethod def _get_timeout(self) -> int: """Get timeout for HTTP requests. Must be implemented by subclasses.""" - pass @abstractmethod def _get_max_file_size(self) -> int: """Get maximum file size in bytes. Must be implemented by subclasses.""" - pass @abstractmethod - def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> Dict[str, Any]: + def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> dict[str, Any]: """Process media data and save to cached path. Must be implemented by subclasses.""" - pass @abstractmethod def _get_media_type_name(self) -> str: """Get media type name for logging. Must be implemented by subclasses.""" - pass async def _get_session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self._get_timeout()), - headers={"User-Agent": "mlx-server-OAI-compat/1.0"} + headers={"User-Agent": "mlx-server-OAI-compat/1.0"}, ) return self._session @@ -118,7 +113,7 @@ def _cleanup_old_files(self): self._evict_oldest_cache_entries() gc.collect() # Force garbage collection after cleanup except Exception as e: - logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {str(e)}") + logger.warning(f"Failed to clean up old {self._get_media_type_name()} files: {e!s}") async def _process_single_media(self, media_url: str, **kwargs) -> str: try: @@ -132,39 +127,40 @@ async def _process_single_media(self, media_url: str, **kwargs) -> str: if os.path.exists(media_url): # Copy local file to cache - with open(media_url, 'rb') as f: + with open(media_url, "rb") as f: data = f.read() - + if not self._validate_media_data(data): raise ValueError(f"Invalid {self._get_media_type_name()} file format") - + return self._process_media_data(data, cached_path, **kwargs) - elif media_url.startswith("data:"): + if media_url.startswith("data:"): _, encoded = media_url.split(",", 1) estimated_size = len(encoded) * 3 / 4 if estimated_size > self._get_max_file_size(): - raise ValueError(f"Base64-encoded {self._get_media_type_name()} exceeds size limit") + raise ValueError( + f"Base64-encoded {self._get_media_type_name()} exceeds size limit" + ) data = base64.b64decode(encoded) - + if not self._validate_media_data(data): raise ValueError(f"Invalid {self._get_media_type_name()} file format") - + + return self._process_media_data(data, cached_path, **kwargs) + session = await self._get_session() + async with session.get(media_url) as response: + response.raise_for_status() + data = await response.read() + + if not self._validate_media_data(data): + raise ValueError(f"Invalid {self._get_media_type_name()} file format") + return self._process_media_data(data, cached_path, **kwargs) - else: - session = await self._get_session() - async with session.get(media_url) as response: - response.raise_for_status() - data = await response.read() - - if not self._validate_media_data(data): - raise ValueError(f"Invalid {self._get_media_type_name()} file format") - - return self._process_media_data(data, cached_path, **kwargs) except Exception as e: - logger.error(f"Failed to process {self._get_media_type_name()}: {str(e)}") - raise ValueError(f"Failed to process {self._get_media_type_name()}: {str(e)}") + logger.error(f"Failed to process {self._get_media_type_name()}: {e!s}") + raise ValueError(f"Failed to process {self._get_media_type_name()}: {e!s}") finally: gc.collect() @@ -175,25 +171,25 @@ def clear_cache(self): gc.collect() async def cleanup(self): - if hasattr(self, '_cleaned') and self._cleaned: + if hasattr(self, "_cleaned") and self._cleaned: return self._cleaned = True try: # Clear caches before cleanup self.clear_cache() - + if self._session and not self._session.closed: await self._session.close() except Exception as e: - logger.warning(f"Exception closing aiohttp session: {str(e)}") + logger.warning(f"Exception closing aiohttp session: {e!s}") try: self.executor.shutdown(wait=True) except Exception as e: - logger.warning(f"Exception shutting down executor: {str(e)}") + logger.warning(f"Exception shutting down executor: {e!s}") try: self.temp_dir.cleanup() except Exception as e: - logger.warning(f"Exception cleaning up temp directory: {str(e)}") + logger.warning(f"Exception cleaning up temp directory: {e!s}") async def __aenter__(self): return self @@ -204,4 +200,4 @@ async def __aexit__(self, exc_type, exc, tb): def __del__(self): # Async cleanup cannot be reliably performed in __del__ # Please use 'async with Processor()' or call 'await cleanup()' explicitly. - pass \ No newline at end of file + pass diff --git a/app/core/handler_process.py b/app/core/handler_process.py index b38dd3fa..1898d5da 100644 --- a/app/core/handler_process.py +++ b/app/core/handler_process.py @@ -439,10 +439,7 @@ async def start(self, queue_config: dict[str, Any]) -> None: name=f"handler-{self.model_id}", ) self._process.start() - logger.info( - f"Spawned handler process for '{self.model_id}' " - f"(pid={self._process.pid})" - ) + logger.info(f"Spawned handler process for '{self.model_id}' (pid={self._process.pid})") # Wait for the ready signal. ready_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() @@ -503,14 +500,10 @@ def _response_reader(self) -> None: pending = self._pending.get(req_id) if pending and self._loop: try: - future = asyncio.run_coroutine_threadsafe( - pending.put(response), self._loop - ) + future = asyncio.run_coroutine_threadsafe(pending.put(response), self._loop) future.result(timeout=60) except concurrent.futures.TimeoutError: - logger.warning( - f"Timeout delivering stream chunk for {req_id}" - ) + logger.warning(f"Timeout delivering stream chunk for {req_id}") except Exception: if self._running: logger.debug( diff --git a/app/core/image_processor.py b/app/core/image_processor.py index c38ccc65..a941a675 100644 --- a/app/core/image_processor.py +++ b/app/core/image_processor.py @@ -1,15 +1,16 @@ -import gc import asyncio -from PIL import Image -from loguru import logger +import gc from io import BytesIO -from typing import List + +from loguru import logger +from PIL import Image + from .base_processor import BaseProcessor class ImageProcessor(BaseProcessor): """Image processor for handling image files with caching, validation, and processing.""" - + def __init__(self, max_workers: int = 4, cache_size: int = 1000): super().__init__(max_workers, cache_size) Image.MAX_IMAGE_PIXELS = 100000000 # Limit to 100 megapixels @@ -23,27 +24,27 @@ def _validate_media_data(self, data: bytes) -> bool: """Basic validation of image data.""" if len(data) < 100: # Too small to be a valid image file return False - + # Check for common image file signatures image_signatures = [ - b'\xff\xd8\xff', # JPEG - b'\x89PNG\r\n\x1a\n', # PNG - b'GIF87a', # GIF87a - b'GIF89a', # GIF89a - b'BM', # BMP - b'II*\x00', # TIFF (little endian) - b'MM\x00*', # TIFF (big endian) - b'RIFF', # WebP (part of RIFF) + b"\xff\xd8\xff", # JPEG + b"\x89PNG\r\n\x1a\n", # PNG + b"GIF87a", # GIF87a + b"GIF89a", # GIF89a + b"BM", # BMP + b"II*\x00", # TIFF (little endian) + b"MM\x00*", # TIFF (big endian) + b"RIFF", # WebP (part of RIFF) ] - + for sig in image_signatures: if data.startswith(sig): return True - + # Additional check for WebP - if data.startswith(b'RIFF') and b'WEBP' in data[:20]: + if data.startswith(b"RIFF") and b"WEBP" in data[:20]: return True - + return False def _get_timeout(self) -> int: @@ -58,7 +59,9 @@ def _get_media_type_name(self) -> str: """Get media type name for logging.""" return "image" - def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 448) -> Image.Image: + def _resize_image_keep_aspect_ratio( + self, image: Image.Image, max_size: int = 448 + ) -> Image.Image: width, height = image.size if width <= max_size and height <= max_size: return image @@ -75,15 +78,15 @@ def _resize_image_keep_aspect_ratio(self, image: Image.Image, max_size: int = 44 return image def _prepare_image_for_saving(self, image: Image.Image) -> Image.Image: - if image.mode in ('RGBA', 'LA'): - background = Image.new('RGB', image.size, (255, 255, 255)) - if image.mode == 'RGBA': + if image.mode in ("RGBA", "LA"): + background = Image.new("RGB", image.size, (255, 255, 255)) + if image.mode == "RGBA": background.paste(image, mask=image.split()[3]) else: background.paste(image, mask=image.split()[1]) return background - elif image.mode != 'RGB': - return image.convert('RGB') + if image.mode != "RGB": + return image.convert("RGB") return image def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: @@ -91,12 +94,12 @@ def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: image = None resize = kwargs.get("resize", True) try: - with Image.open(BytesIO(data), mode='r') as image: + with Image.open(BytesIO(data), mode="r") as image: if resize: image = self._resize_image_keep_aspect_ratio(image) image = self._prepare_image_for_saving(image) - image.save(cached_path, 'PNG', quality=100, optimize=True) - + image.save(cached_path, "PNG", quality=100, optimize=True) + self._cleanup_old_files() return cached_path finally: @@ -111,10 +114,10 @@ async def process_image_url(self, image_url: str, resize: bool = True) -> str: """Process a single image URL and return path to cached file.""" return await self._process_single_media(image_url, resize=resize) - async def process_image_urls(self, image_urls: List[str], resize: bool = True) -> List[str]: + async def process_image_urls(self, image_urls: list[str], resize: bool = True) -> list[str]: """Process multiple image URLs and return paths to cached files.""" tasks = [self.process_image_url(url, resize=resize) for url in image_urls] results = await asyncio.gather(*tasks, return_exceptions=True) # Force garbage collection after batch processing gc.collect() - return results \ No newline at end of file + return results diff --git a/app/core/inference_worker.py b/app/core/inference_worker.py index c76e54cd..31563766 100644 --- a/app/core/inference_worker.py +++ b/app/core/inference_worker.py @@ -3,11 +3,11 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncGenerator, Callable, Generator import queue import threading -from collections.abc import AsyncGenerator, Generator from threading import Thread -from typing import Any, Callable +from typing import Any from loguru import logger @@ -85,9 +85,7 @@ def _record(self, success: bool) -> None: else: self._failed += 1 - async def submit( - self, func: Callable[..., Any], *args: Any, **kwargs: Any - ) -> Any: + async def submit(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Run func on the worker thread; await its result. Raises QueueFull, TimeoutError, or func's exception.""" loop = asyncio.get_running_loop() future: asyncio.Future[Any] = loop.create_future() @@ -107,7 +105,7 @@ def _work() -> None: raise asyncio.QueueFull("Inference queue is full") try: return await asyncio.wait_for(future, timeout=self._timeout) - except asyncio.TimeoutError: + except TimeoutError: raise TimeoutError(f"Inference timed out after {self._timeout}s") def submit_stream( diff --git a/app/core/model_registry.py b/app/core/model_registry.py index 590f3726..005e18ec 100644 --- a/app/core/model_registry.py +++ b/app/core/model_registry.py @@ -82,8 +82,7 @@ async def register_model( self._metadata[model_id] = metadata logger.info( - f"Registered model: {model_id} (type={model_type}, " - f"context_length={context_length})" + f"Registered model: {model_id} (type={model_type}, context_length={context_length})" ) def get_handler(self, model_id: str) -> Any: @@ -107,8 +106,7 @@ def get_handler(self, model_id: str) -> Any: if model_id not in self._handlers: available = ", ".join(sorted(self._handlers.keys())) or "(none)" raise KeyError( - f"Model '{model_id}' not found in registry. " - f"Available models: {available}" + f"Model '{model_id}' not found in registry. Available models: {available}" ) return self._handlers[model_id] @@ -203,9 +201,7 @@ async def cleanup_all(self) -> None: logger.info("All models unregistered and cleaned up") @staticmethod - async def _cleanup_single_handler( - model_id: str, handler: Any - ) -> None: + async def _cleanup_single_handler(model_id: str, handler: Any) -> None: """Clean up a single handler, logging success or failure. Parameters diff --git a/app/core/video_processor.py b/app/core/video_processor.py index 92d7141d..0123ea41 100644 --- a/app/core/video_processor.py +++ b/app/core/video_processor.py @@ -1,18 +1,19 @@ -import os -import gc import asyncio +import gc +import os + from loguru import logger -from typing import List + from .base_processor import BaseProcessor class VideoProcessor(BaseProcessor): """Video processor for handling video files with caching, validation, and processing.""" - + def __init__(self, max_workers: int = 4, cache_size: int = 1000): super().__init__(max_workers, cache_size) # Supported video formats - self._supported_formats = {'.mp4', '.avi', '.mov'} + self._supported_formats = {".mp4", ".avi", ".mov"} def _get_media_format(self, media_url: str, data: bytes = None) -> str: """Determine video format from URL or data.""" @@ -21,16 +22,16 @@ def _get_media_format(self, media_url: str, data: bytes = None) -> str: mime_type = media_url.split(";")[0].split(":")[1] if "mp4" in mime_type: return "mp4" - elif "quicktime" in mime_type or "mov" in mime_type: + if "quicktime" in mime_type or "mov" in mime_type: return "mov" - elif "x-msvideo" in mime_type or "avi" in mime_type: + if "x-msvideo" in mime_type or "avi" in mime_type: return "avi" else: # Extract format from file extension ext = os.path.splitext(media_url.lower())[1] if ext in self._supported_formats: return ext[1:] # Remove the dot - + # Default to mp4 if format cannot be determined return "mp4" @@ -38,50 +39,45 @@ def _validate_media_data(self, data: bytes) -> bool: """Basic validation of video data.""" if len(data) < 100: # Too small to be a valid video file return False - + # Check for common video file signatures video_signatures = [ # MP4/M4V/MOV (ISO Base Media File Format) - (b'\x00\x00\x00\x14ftypisom', 0), # MP4 - (b'\x00\x00\x00\x18ftyp', 0), # MP4/MOV - (b'\x00\x00\x00\x1cftyp', 0), # MP4/MOV - (b'\x00\x00\x00\x20ftyp', 0), # MP4/MOV - (b'ftyp', 4), # MP4/MOV (ftyp at offset 4) - + (b"\x00\x00\x00\x14ftypisom", 0), # MP4 + (b"\x00\x00\x00\x18ftyp", 0), # MP4/MOV + (b"\x00\x00\x00\x1cftyp", 0), # MP4/MOV + (b"\x00\x00\x00\x20ftyp", 0), # MP4/MOV + (b"ftyp", 4), # MP4/MOV (ftyp at offset 4) # AVI - (b'RIFF', 0), # AVI (also check for 'AVI ' at offset 8) - + (b"RIFF", 0), # AVI (also check for 'AVI ' at offset 8) # WebM/MKV (Matroska) - (b'\x1a\x45\xdf\xa3', 0), # Matroska/WebM - + (b"\x1a\x45\xdf\xa3", 0), # Matroska/WebM # FLV - (b'FLV\x01', 0), # Flash Video - + (b"FLV\x01", 0), # Flash Video # MPEG - (b'\x00\x00\x01\xba', 0), # MPEG PS - (b'\x00\x00\x01\xb3', 0), # MPEG PS - + (b"\x00\x00\x01\xba", 0), # MPEG PS + (b"\x00\x00\x01\xb3", 0), # MPEG PS # QuickTime - (b'moov', 0), # QuickTime - (b'mdat', 0), # QuickTime + (b"moov", 0), # QuickTime + (b"mdat", 0), # QuickTime ] - + for sig, offset in video_signatures: if len(data) > offset + len(sig): - if data[offset:offset+len(sig)] == sig: + if data[offset : offset + len(sig)] == sig: # Additional validation for AVI - if sig == b'RIFF' and len(data) > 12: - if data[8:12] == b'AVI ': + if sig == b"RIFF" and len(data) > 12: + if data[8:12] == b"AVI ": return True - elif sig == b'RIFF': + elif sig == b"RIFF": continue # Not AVI, might be WAV else: return True - + # Check for ftyp box anywhere in first 32 bytes (MP4/MOV) - if b'ftyp' in data[:32]: + if b"ftyp" in data[:32]: return True - + # Allow unknown formats to pass through for flexibility return True @@ -96,14 +92,14 @@ def _get_max_file_size(self) -> int: def _process_media_data(self, data: bytes, cached_path: str, **kwargs) -> str: """Process video data and save to cached path.""" try: - with open(cached_path, 'wb') as f: + with open(cached_path, "wb") as f: f.write(data) - + logger.info(f"Saved video to {cached_path} ({len(data)} bytes)") self._cleanup_old_files() return cached_path except Exception as e: - logger.error(f"Failed to save video data: {str(e)}") + logger.error(f"Failed to save video data: {e!s}") raise def _get_media_type_name(self) -> str: @@ -113,27 +109,27 @@ def _get_media_type_name(self) -> str: async def process_video_url(self, video_url: str) -> str: """ Process a single video URL and return path to cached file. - + Supports: - HTTP/HTTPS URLs (downloads video) - Local file paths (copies to cache) - Data URLs (base64 encoded videos) - + Args: video_url: URL, file path, or data URL of the video - + Returns: Path to the cached video file in temp directory """ return await self._process_single_media(video_url) - async def process_video_urls(self, video_urls: List[str]) -> List[str]: + async def process_video_urls(self, video_urls: list[str]) -> list[str]: """ Process multiple video URLs and return paths to cached files. - + Args: video_urls: List of URLs, file paths, or data URLs of videos - + Returns: List of paths to cached video files """ diff --git a/app/handler/mflux.py b/app/handler/mflux.py index c54b770b..ef5abbcb 100644 --- a/app/handler/mflux.py +++ b/app/handler/mflux.py @@ -1,14 +1,14 @@ import asyncio import base64 +import gc +from http import HTTPStatus import io +from io import BytesIO import os import tempfile import time +from typing import Any import uuid -import gc -from io import BytesIO -from http import HTTPStatus -from typing import Any, Dict, List, Optional, Union from fastapi import HTTPException, UploadFile from loguru import logger @@ -35,12 +35,18 @@ class MLXFluxHandler: handler_type: str = "image" - def __init__(self, model_path: str, max_concurrency: int = 1, quantize: Optional[int] = None, - config_name: str = "flux-schnell", lora_paths: Optional[List[str]] = None, - lora_scales: Optional[List[float]] = None): + def __init__( + self, + model_path: str, + max_concurrency: int = 1, + quantize: int | None = None, + config_name: str = "flux-schnell", + lora_paths: list[str] | None = None, + lora_scales: list[float] | None = None, + ): """ Initialize the handler with the specified model path. - + Args: model_path (str): Path to the model directory or model name for Flux. max_concurrency (int): Maximum number of concurrent model inference tasks. @@ -54,40 +60,44 @@ def __init__(self, model_path: str, max_concurrency: int = 1, quantize: Optional self.config_name = config_name self.lora_paths = lora_paths self.lora_scales = lora_scales - + self.model = ImageGenerationModel( - model_path=model_path, + model_path=model_path, quantize=quantize, config_name=config_name, lora_paths=lora_paths, - lora_scales=lora_scales + lora_scales=lora_scales, ) self.model_created = int(time.time()) # Store creation time when model is loaded - + # Dedicated inference thread — keeps the event loop free during # blocking MLX model computation. self.inference_worker = InferenceWorker() - logger.info(f"Initialized MLXFluxHandler with model path: {model_path}, config name: {config_name}") + logger.info( + f"Initialized MLXFluxHandler with model path: {model_path}, config name: {config_name}" + ) if lora_paths: logger.info(f"Using LoRA adapters: {lora_paths} with scales: {lora_scales}") - async def get_models(self) -> List[Dict[str, Any]]: + async def get_models(self) -> list[dict[str, Any]]: """ Get list of available models with their metadata. """ try: - return [{ - "id": self.model_path, - "object": "model", - "created": self.model_created, - "owned_by": "local" - }] + return [ + { + "id": self.model_path, + "object": "model", + "created": self.model_created, + "owned_by": "local", + } + ] except Exception as e: - logger.error(f"Error getting models: {str(e)}") + logger.error(f"Error getting models: {e!s}") return [] - async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> None: + async def initialize(self, queue_config: dict[str, Any] | None = None) -> None: """Initialize the handler and start the inference worker. Parameters @@ -112,29 +122,26 @@ async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> Non def _parse_image_size(self, size: ImageSize) -> tuple[int, int]: """ Parse image size string to width, height tuple. - + Parameters ---------- size : ImageSize Image size enum value (e.g., "1024x1024"). - + Returns ------- tuple[int, int] Width and height as integers. """ - width, height = map(int, size.value.split('x')) + width, height = map(int, size.value.split("x")) return width, height - + def _build_generation_request_data( - self, - request: ImageGenerationRequest, - width: int, - height: int + self, request: ImageGenerationRequest, width: int, height: int ) -> dict[str, Any]: """ Build request data dictionary for image generation. - + Parameters ---------- request : ImageGenerationRequest @@ -143,7 +150,7 @@ def _build_generation_request_data( Image width in pixels. height : int Image height in pixels. - + Returns ------- dict[str, Any] @@ -156,20 +163,22 @@ def _build_generation_request_data( "seed": request.seed, "guidance": request.guidance_scale, "width": width, - "height": height + "height": height, } - - def _build_edit_request_data(self, image_edit_request: ImageEditRequest, temp_file_paths: list[str]) -> dict[str, Any]: + + def _build_edit_request_data( + self, image_edit_request: ImageEditRequest, temp_file_paths: list[str] + ) -> dict[str, Any]: """ Build request data dictionary for image editing. - + Parameters ---------- image_edit_request : ImageEditRequest The image editing request. temp_file_paths : list[str] List of temporary file paths. - + Returns ------- dict[str, Any] @@ -182,18 +191,18 @@ def _build_edit_request_data(self, image_edit_request: ImageEditRequest, temp_fi "steps": image_edit_request.steps, "seed": image_edit_request.seed, "guidance": image_edit_request.guidance_scale, - "image_paths": temp_file_paths + "image_paths": temp_file_paths, } - + def _create_image_response(self, image_result: Image.Image) -> ImageGenerationResponse: """ Create image generation response from PIL Image. - + Parameters ---------- image_result : Image.Image The generated PIL Image. - + Returns ------- ImageGenerationResponse @@ -201,14 +210,13 @@ def _create_image_response(self, image_result: Image.Image) -> ImageGenerationRe """ image_data_b64 = self._image_to_base64(image_result) return ImageGenerationResponse( - created=int(time.time()), - data=[ImageData(b64_json=image_data_b64)] + created=int(time.time()), data=[ImageData(b64_json=image_data_b64)] ) def _create_edit_response(self, image_result: Image.Image) -> ImageEditResponse: """ Create image editing response from PIL Image. - + Parameters ---------- image_result : Image.Image @@ -216,19 +224,18 @@ def _create_edit_response(self, image_result: Image.Image) -> ImageEditResponse: """ image_data_b64 = self._image_to_base64(image_result) return ImageEditResponse( - created=int(time.time()), - data=[ImageData(b64_json=image_data_b64)] + created=int(time.time()), data=[ImageData(b64_json=image_data_b64)] ) - + def _handle_queue_full_error(self, request_id: str) -> None: """ Handle queue capacity errors. - + Parameters ---------- request_id : str The request ID for logging. - + Raises ------ HTTPException @@ -236,92 +243,90 @@ def _handle_queue_full_error(self, request_id: str) -> None: """ logger.error(f"Queue at capacity for request {request_id}") content = create_error_response( - "Too many requests. Service is at capacity.", - "rate_limit_exceeded", - HTTPStatus.TOO_MANY_REQUESTS + "Too many requests. Service is at capacity.", + "rate_limit_exceeded", + HTTPStatus.TOO_MANY_REQUESTS, ) raise HTTPException(status_code=429, detail=content) - + def _handle_generation_error(self, request_id: str, error: Exception) -> None: """ Handle general generation errors. - + Parameters ---------- request_id : str The request ID for logging. error : Exception The exception that occurred. - + Raises ------ HTTPException 500 error with error details. """ - logger.error(f"Error in image generation for request {request_id}: {str(error)}") + logger.error(f"Error in image generation for request {request_id}: {error!s}") content = create_error_response( - f"Failed to generate image: {str(error)}", - "server_error", - HTTPStatus.INTERNAL_SERVER_ERROR + f"Failed to generate image: {error!s}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR ) raise HTTPException(status_code=500, detail=content) def _handle_edit_error(self, request_id: str, error: Exception) -> None: """ Handle general editing errors. - + Parameters ---------- request_id : str The request ID for logging. error : Exception The exception that occurred. - + Raises ------ HTTPException 500 error with error details. """ - logger.error(f"Error in image editing for request {request_id}: {str(error)}") + logger.error(f"Error in image editing for request {request_id}: {error!s}") content = create_error_response( - f"Failed to edit image: {str(error)}", - "server_error", - HTTPStatus.INTERNAL_SERVER_ERROR + f"Failed to edit image: {error!s}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR ) raise HTTPException(status_code=500, detail=content) - + def _validate_image_file(self, image: UploadFile, idx: int) -> None: """ Validate image file type and size. - + Parameters ---------- image : UploadFile The uploaded image file to validate. idx : int Index of the image (for error messages). - + Raises ------ HTTPException If validation fails. """ - if not image.content_type or image.content_type not in ["image/png", "image/jpeg", "image/jpg"]: + if not image.content_type or image.content_type not in [ + "image/png", + "image/jpeg", + "image/jpg", + ]: raise HTTPException( - status_code=400, - detail=f"Image {idx + 1} must be a PNG, JPEG, or JPG file" + status_code=400, detail=f"Image {idx + 1} must be a PNG, JPEG, or JPG file" ) - - if hasattr(image, 'size') and image.size and image.size > 10 * 1024 * 1024: + + if hasattr(image, "size") and image.size and image.size > 10 * 1024 * 1024: raise HTTPException( - status_code=400, - detail=f"Image {idx + 1} file size must be less than 10MB" + status_code=400, detail=f"Image {idx + 1} file size must be less than 10MB" ) - + async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: str) -> str: """ Read, process, and save uploaded image to a temporary file. - + Parameters ---------- image : UploadFile @@ -330,12 +335,12 @@ async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: st Index of the image (for error messages and file naming). request_id : str Request ID for file naming. - + Returns ------- str Path to the temporary file. - + Raises ------ HTTPException @@ -345,39 +350,34 @@ async def _upload_to_temp_file(self, image: UploadFile, idx: int, request_id: st image_data = await image.read() if not image_data: raise HTTPException( - status_code=400, - detail=f"Empty image file received for image {idx + 1}" + status_code=400, detail=f"Empty image file received for image {idx + 1}" ) - + # Load and process image try: input_image = Image.open(io.BytesIO(image_data)).convert("RGB") input_image = ImageOps.exif_transpose(input_image) except Exception as img_error: - logger.error(f"Failed to process image {idx + 1}: {str(img_error)}") + logger.error(f"Failed to process image {idx + 1}: {img_error!s}") raise HTTPException( - status_code=400, - detail=f"Invalid or corrupted image file for image {idx + 1}" + status_code=400, detail=f"Invalid or corrupted image file for image {idx + 1}" ) - + # Create and save to temporary file try: temp_file = tempfile.NamedTemporaryFile( - delete=False, - suffix=".png", - prefix=f"edit_{request_id}_{idx + 1}_" + delete=False, suffix=".png", prefix=f"edit_{request_id}_{idx + 1}_" ) temp_file_path = temp_file.name input_image.save(temp_file_path, format="PNG") temp_file.close() return temp_file_path except Exception as temp_error: - logger.error(f"Failed to create temporary file for image {idx + 1}: {str(temp_error)}") + logger.error(f"Failed to create temporary file for image {idx + 1}: {temp_error!s}") raise HTTPException( - status_code=500, - detail=f"Failed to process image {idx + 1} for editing" + status_code=500, detail=f"Failed to process image {idx + 1} for editing" ) - + def _cleanup_temp_files(self, temp_file_paths: list[str]) -> None: """Clean up temporary files and force garbage collection.""" for temp_file_path in temp_file_paths: @@ -386,65 +386,65 @@ def _cleanup_temp_files(self, temp_file_paths: list[str]) -> None: os.unlink(temp_file_path) logger.debug(f"Cleaned up temporary file: {temp_file_path}") except OSError as cleanup_error: - logger.warning(f"Failed to cleanup temporary file {temp_file_path}: {str(cleanup_error)}") + logger.warning( + f"Failed to cleanup temporary file {temp_file_path}: {cleanup_error!s}" + ) gc.collect() async def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: """ Generate an image based on the request parameters. - + Parameters ---------- request : ImageGenerationRequest Request object containing the generation parameters. - + Returns ------- ImageGenerationResponse Response containing the generated image data. - + Raises ------ HTTPException For queue capacity issues or processing failures. """ request_id = f"image-{uuid.uuid4()}" - + try: # Parse image dimensions width, height = 1024, 1024 if request.size: width, height = self._parse_image_size(request.size) - + # Build and submit request to the inference thread request_data = self._build_generation_request_data(request, width, height) - image_result = await self.inference_worker.submit( - self._run_inference, request_data - ) - + image_result = await self.inference_worker.submit(self._run_inference, request_data) + # Create and return response return self._create_image_response(image_result) - + except asyncio.QueueFull: self._handle_queue_full_error(request_id) - + except Exception as e: self._handle_generation_error(request_id, e) async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditResponse: """ Edit an image or multiple images based on the request parameters. - + Parameters ---------- image_edit_request : ImageEditRequest Request parameters for image editing. - + Returns ------- ImageEditResponse Response containing the edited image data. - + Raises ------ HTTPException @@ -452,23 +452,22 @@ async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditRes """ # Normalize and validate inputs images: list[UploadFile] = ( - image_edit_request.image - if isinstance(image_edit_request.image, list) + image_edit_request.image + if isinstance(image_edit_request.image, list) else [image_edit_request.image] ) - + if not images: raise HTTPException( - status_code=400, - detail="At least one image is required for image editing" + status_code=400, detail="At least one image is required for image editing" ) - + for idx, image in enumerate(images): self._validate_image_file(image, idx) request_id = f"image-edit-{uuid.uuid4()}" temp_file_paths: list[str] = [] - + try: # Process all images to temporary files for idx, image in enumerate(images): @@ -478,31 +477,29 @@ async def edit_image(self, image_edit_request: ImageEditRequest) -> ImageEditRes # Submit request to the inference thread request_data = self._build_edit_request_data(image_edit_request, temp_file_paths) - image_result = await self.inference_worker.submit( - self._run_inference, request_data - ) - + image_result = await self.inference_worker.submit(self._run_inference, request_data) + return self._create_edit_response(image_result) except asyncio.QueueFull: self._handle_queue_full_error(request_id) - + except HTTPException: raise - + except Exception as e: self._handle_edit_error(request_id, e) - + finally: self._cleanup_temp_files(temp_file_paths) - + def _image_to_base64(self, image: Image.Image) -> str: """ Convert PIL Image to base64 string. - + Args: image: PIL Image object. - + Returns: str: Base64 encoded image string. """ @@ -510,9 +507,9 @@ def _image_to_base64(self, image: Image.Image) -> str: image.save(buffer, format="PNG") buffer.seek(0) image_data = buffer.getvalue() - return base64.b64encode(image_data).decode('utf-8') + return base64.b64encode(image_data).decode("utf-8") - def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image: + def _run_inference(self, request_data: dict[str, Any]) -> Image.Image: """Execute image generation/editing on the inference thread. This method is submitted to ``InferenceWorker.submit`` and runs @@ -539,7 +536,7 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image: image_path = request_data.get("image_path") # For image editing guidance = request_data.get("guidance") image_paths = request_data.get("image_paths", []) - + # Prepare model parameters model_params = { "num_inference_steps": steps, @@ -548,20 +545,22 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image: "guidance": guidance, "image_paths": image_paths, } - + # Add negative prompt if provided if negative_prompt: model_params["negative_prompt"] = negative_prompt - + # Add image path for image editing if provided if image_path: model_params["image_path"] = image_path - logger.info(f"Processing image edit with prompt: {prompt[:50]}... and image: {image_path}") + logger.info( + f"Processing image edit with prompt: {prompt[:50]}... and image: {image_path}" + ) else: logger.info(f"Generating image with prompt: {prompt[:50]}...") - + # Log all model parameters - logger.info(f"Model inference configurations:") + logger.info("Model inference configurations:") logger.info(f" - Prompt: {prompt[:100]}...") logger.info(f" - Negative prompt: {negative_prompt}") logger.info(f" - Steps: {steps}") @@ -571,22 +570,16 @@ def _run_inference(self, request_data: Dict[str, Any]) -> Image.Image: logger.info(f" - Guidance scale: {guidance}") logger.info(f" - Image path: {image_path}") logger.info(f" - Model params: {model_params}") - + # Generate image - image = self.model( - prompt=prompt, - seed=seed, - **model_params - ) + image = self.model(prompt=prompt, seed=seed, **model_params) return image - + except Exception as e: - logger.error(f"Error processing image generation request: {str(e)}") + logger.error(f"Error processing image generation request: {e!s}") raise - async def edit_image_from_paths( - self, edit_data: Dict[str, Any] - ) -> ImageEditResponse: + async def edit_image_from_paths(self, edit_data: dict[str, Any]) -> ImageEditResponse: """Edit an image from pre-saved file paths. This method is used by ``HandlerProcessProxy`` for IPC: the @@ -620,9 +613,7 @@ async def edit_image_from_paths( "image_paths": temp_file_paths, } - image_result = await self.inference_worker.submit( - self._run_inference, request_data - ) + image_result = await self.inference_worker.submit(self._run_inference, request_data) return self._create_edit_response(image_result) except asyncio.QueueFull: @@ -634,7 +625,7 @@ async def edit_image_from_paths( finally: self._cleanup_temp_files(temp_file_paths) - async def get_queue_stats(self) -> Dict[str, Any]: + async def get_queue_stats(self) -> dict[str, Any]: """Get statistics from the inference worker. Returns @@ -642,24 +633,24 @@ async def get_queue_stats(self) -> Dict[str, Any]: dict[str, Any] Dictionary with worker statistics. """ - if not hasattr(self, 'inference_worker'): + if not hasattr(self, "inference_worker"): return {"error": "Inference worker not initialized"} return self.inference_worker.get_stats() async def cleanup(self) -> None: """Clean up resources and stop the inference worker.""" - if hasattr(self, '_cleaned') and self._cleaned: + if hasattr(self, "_cleaned") and self._cleaned: return self._cleaned = True try: logger.info("Cleaning up MLXFluxHandler resources") - if hasattr(self, 'inference_worker'): + if hasattr(self, "inference_worker"): self.inference_worker.stop() logger.info("Inference worker stopped successfully") except Exception as e: - logger.error(f"Error during MLXFluxHandler cleanup: {str(e)}") + logger.error(f"Error during MLXFluxHandler cleanup: {e!s}") # Force garbage collection gc.collect() @@ -671,11 +662,12 @@ def __del__(self): Note: Async cleanup cannot be reliably performed in __del__. Please use 'await cleanup()' explicitly. """ - if hasattr(self, '_cleaned') and self._cleaned: + if hasattr(self, "_cleaned") and self._cleaned: return # Set flag to prevent multiple cleanup attempts self._cleaned = True + if __name__ == "__main__": handler = MLXFluxHandler(model_path="qwen-image", config_name="qwen-image") - print(handler.model.get_model_info("qwen-image")) \ No newline at end of file + print(handler.model.get_model_info("qwen-image")) diff --git a/app/handler/mlx_embeddings.py b/app/handler/mlx_embeddings.py index dd4d9d14..c0565fe2 100644 --- a/app/handler/mlx_embeddings.py +++ b/app/handler/mlx_embeddings.py @@ -1,8 +1,8 @@ import gc +from http import HTTPStatus import time +from typing import Any import uuid -from http import HTTPStatus -from typing import Any, Dict, List from fastapi import HTTPException from loguru import logger @@ -12,6 +12,7 @@ from ..schemas.openai import EmbeddingRequest from ..utils.errors import create_error_response + class MLXEmbeddingsHandler: """ Handler class for making requests to the underlying MLX embeddings model service. @@ -23,7 +24,7 @@ class MLXEmbeddingsHandler: def __init__(self, model_path: str, max_concurrency: int = 1): """ Initialize the handler with the specified model path. - + Args: model_path (str): Path to the embeddings model to load. max_concurrency (int): Maximum number of concurrent model inference tasks. @@ -37,23 +38,25 @@ def __init__(self, model_path: str, max_concurrency: int = 1): self.inference_worker = InferenceWorker() logger.info(f"Initialized MLXEmbeddingsHandler with model path: {model_path}") - - async def get_models(self) -> List[Dict[str, Any]]: + + async def get_models(self) -> list[dict[str, Any]]: """ Get list of available models with their metadata. """ try: - return [{ - "id": self.model_path, - "object": "model", - "created": self.model_created, - "owned_by": "local" - }] + return [ + { + "id": self.model_path, + "object": "model", + "created": self.model_created, + "owned_by": "local", + } + ] except Exception as e: - logger.error(f"Error getting models: {str(e)}") + logger.error(f"Error getting models: {e!s}") return [] - async def initialize(self, config: Dict[str, Any]) -> None: + async def initialize(self, config: dict[str, Any]) -> None: """Initialize the handler and start the inference worker. Parameters @@ -72,10 +75,10 @@ async def initialize(self, config: Dict[str, Any]) -> None: async def generate_embeddings_response(self, request: EmbeddingRequest): """ Generate embeddings for a given text input. - + Args: request: EmbeddingRequest object containing the text input. - + Returns: List[float]: Embeddings for the input text. """ @@ -88,17 +91,21 @@ async def generate_embeddings_response(self, request: EmbeddingRequest): response = await self.inference_worker.submit( self.model, texts=request.input, - max_length=getattr(request, 'max_length', 512), + max_length=getattr(request, "max_length", 512), ) - + return response except Exception as e: - logger.error(f"Error in embeddings generation: {str(e)}") - content = create_error_response(f"Failed to generate embeddings: {str(e)}", "server_error", HTTPStatus.INTERNAL_SERVER_ERROR) + logger.error(f"Error in embeddings generation: {e!s}") + content = create_error_response( + f"Failed to generate embeddings: {e!s}", + "server_error", + HTTPStatus.INTERNAL_SERVER_ERROR, + ) raise HTTPException(status_code=500, detail=content) - async def get_queue_stats(self) -> Dict[str, Any]: + async def get_queue_stats(self) -> dict[str, Any]: """Get statistics from the inference worker. Returns @@ -118,15 +125,15 @@ async def cleanup(self) -> None: """ try: logger.info("Cleaning up MLXEmbeddingsHandler resources") - if hasattr(self, 'inference_worker'): + if hasattr(self, "inference_worker"): self.inference_worker.stop() - if hasattr(self, 'model'): + if hasattr(self, "model"): self.model.cleanup() # Force garbage collection gc.collect() logger.info("MLXEmbeddingsHandler cleanup completed successfully") except Exception as e: - logger.error(f"Error during MLXEmbeddingsHandler cleanup: {str(e)}") + logger.error(f"Error during MLXEmbeddingsHandler cleanup: {e!s}") raise def __del__(self): @@ -135,7 +142,7 @@ def __del__(self): Note: Async cleanup cannot be reliably performed in __del__. Please use 'await cleanup()' explicitly. """ - if hasattr(self, '_cleaned') and self._cleaned: + if hasattr(self, "_cleaned") and self._cleaned: return # Set flag to prevent multiple cleanup attempts - self._cleaned = True \ No newline at end of file + self._cleaned = True diff --git a/app/handler/mlx_lm.py b/app/handler/mlx_lm.py index e5ff24d8..460049e2 100644 --- a/app/handler/mlx_lm.py +++ b/app/handler/mlx_lm.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import AsyncGenerator -from dataclasses import dataclass, field +from dataclasses import dataclass import gc from http import HTTPStatus import time @@ -223,9 +223,7 @@ def refine_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any] return [{k: v for k, v in message.items() if v is not None} for message in messages] - async def _build_inference_context( - self, request: ChatCompletionRequest - ) -> "_InferenceContext": + async def _build_inference_context(self, request: ChatCompletionRequest) -> "_InferenceContext": """Build the common inference context shared by stream and non-stream paths. Handles: request parsing, message refinement, prompt encoding, KV cache @@ -890,7 +888,9 @@ async def _prepare_text_request( try: # Extract only the fields consumed by MLX_LM.__call__ instead of # serializing the entire Pydantic model with model_dump(). - chat_template_kwargs = request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {} + chat_template_kwargs = ( + request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {} + ) if request.tools: tools = [t.model_dump() for t in request.tools] diff --git a/app/handler/mlx_vlm.py b/app/handler/mlx_vlm.py index 7f2a4cb7..3c287d31 100644 --- a/app/handler/mlx_vlm.py +++ b/app/handler/mlx_vlm.py @@ -213,7 +213,9 @@ async def generate_multimodal_stream(self, request: ChatCompletionRequest): """ try: - input_prompt, model_params, parsers_result = await self._build_inference_context(request) + input_prompt, model_params, parsers_result = await self._build_inference_context( + request + ) if self.debug: log_debug_model_dispatch( @@ -478,7 +480,9 @@ async def generate_multimodal_response(self, request: ChatCompletionRequest): str: Complete response. """ try: - input_prompt, model_params, parsers_result = await self._build_inference_context(request) + input_prompt, model_params, parsers_result = await self._build_inference_context( + request + ) if self.debug: log_debug_model_dispatch( @@ -783,7 +787,9 @@ async def _prepare_multimodal_request( # Extract only the fields consumed downstream instead of serializing # the entire Pydantic model with model_dump(). - chat_template_kwargs = request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {} + chat_template_kwargs = ( + request.chat_template_kwargs.model_dump() if request.chat_template_kwargs else {} + ) if request.tools: tools = [t.model_dump() for t in request.tools] diff --git a/app/handler/mlx_whisper.py b/app/handler/mlx_whisper.py index e62d1b7c..8501bf8c 100644 --- a/app/handler/mlx_whisper.py +++ b/app/handler/mlx_whisper.py @@ -1,11 +1,12 @@ +from collections.abc import AsyncGenerator import gc +from http import HTTPStatus import json import os import tempfile import time +from typing import Any import uuid -from http import HTTPStatus -from typing import Any, AsyncGenerator, Dict, List, Optional from fastapi import HTTPException from loguru import logger @@ -23,6 +24,7 @@ ) from ..utils.errors import create_error_response + class MLXWhisperHandler: """ Handler class for making requests to the underlying MLX Whisper model service. @@ -34,7 +36,7 @@ class MLXWhisperHandler: def __init__(self, model_path: str, max_concurrency: int = 1): """ Initialize the handler with the specified model path. - + Args: model_path (str): Path to the model directory. max_concurrency (int): Maximum number of concurrent model inference tasks. @@ -48,23 +50,25 @@ def __init__(self, model_path: str, max_concurrency: int = 1): self.inference_worker = InferenceWorker() logger.info(f"Initialized MLXWhisperHandler with model path: {model_path}") - - async def get_models(self) -> List[Dict[str, Any]]: + + async def get_models(self) -> list[dict[str, Any]]: """ Get list of available models with their metadata. """ try: - return [{ - "id": self.model_path, - "object": "model", - "created": self.model_created, - "owned_by": "local" - }] + return [ + { + "id": self.model_path, + "object": "model", + "created": self.model_created, + "owned_by": "local", + } + ] except Exception as e: - logger.error(f"Error getting models: {str(e)}") + logger.error(f"Error getting models: {e!s}") return [] - - async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> None: + + async def initialize(self, queue_config: dict[str, Any] | None = None) -> None: """Initialize the handler and start the inference worker. Parameters @@ -85,13 +89,15 @@ async def initialize(self, queue_config: Optional[Dict[str, Any]] = None) -> Non self.inference_worker.start() logger.info("Initialized MLXWhisperHandler and started inference worker") - async def generate_transcription_response(self, request: TranscriptionRequest) -> TranscriptionResponse: + async def generate_transcription_response( + self, request: TranscriptionRequest + ) -> TranscriptionResponse: """ Generate a transcription response for the given request. """ request_id = f"transcription-{uuid.uuid4()}" temp_file_path = None - + try: request_data = await self._prepare_transcription_request(request) temp_file_path = request_data.get("audio_path") @@ -106,15 +112,13 @@ async def generate_transcription_response(self, request: TranscriptionRequest) - response_data = TranscriptionResponse( text=response["text"], usage=TranscriptionUsageAudio( - type="duration", - seconds=int(calculate_audio_duration(temp_file_path)) - ) + type="duration", seconds=int(calculate_audio_duration(temp_file_path)) + ), ) if request.response_format == TranscriptionResponseFormat.JSON: return response_data - else: - # dump to string for text response - return json.dumps(response_data.model_dump()) + # dump to string for text response + return json.dumps(response_data.model_dump()) finally: # Clean up temporary file if temp_file_path and os.path.exists(temp_file_path): @@ -122,17 +126,15 @@ async def generate_transcription_response(self, request: TranscriptionRequest) - os.unlink(temp_file_path) logger.debug(f"Cleaned up temporary file: {temp_file_path}") except Exception as e: - logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}") + logger.warning(f"Failed to clean up temporary file {temp_file_path}: {e!s}") async def generate_transcription_stream_from_data( - self, - request_data: Dict[str, Any], - response_format: TranscriptionResponseFormat + self, request_data: dict[str, Any], response_format: TranscriptionResponseFormat ) -> AsyncGenerator[str, None]: """ Generate a transcription stream from prepared request data. Yields SSE-formatted chunks with timing information. - + Args: request_data: Prepared request data with audio_path already saved response_format: The response format (json or text) @@ -140,7 +142,7 @@ async def generate_transcription_stream_from_data( request_id = f"transcription-{uuid.uuid4()}" created_time = int(time.time()) temp_file_path = request_data.get("audio_path") - + try: # Set stream mode and submit to inference thread request_data["stream"] = True @@ -164,17 +166,14 @@ async def generate_transcription_stream_from_data( model=self.model_path, choices=[ TranscriptionResponseStreamChoice( - delta=Delta( - content=chunk.get("text", "") - ), - finish_reason=None + delta=Delta(content=chunk.get("text", "")), finish_reason=None ) - ] + ], ) - + # Yield as SSE format yield f"data: {stream_response.model_dump_json()}\n\n" - + # Send final chunk with finish_reason final_response = TranscriptionResponseStream( id=request_id, @@ -182,17 +181,14 @@ async def generate_transcription_stream_from_data( created=created_time, model=self.model_path, choices=[ - TranscriptionResponseStreamChoice( - delta=Delta(content=""), - finish_reason="stop" - ) - ] + TranscriptionResponseStreamChoice(delta=Delta(content=""), finish_reason="stop") + ], ) yield f"data: {final_response.model_dump_json()}\n\n" yield "data: [DONE]\n\n" - + except Exception as e: - logger.error(f"Error during transcription streaming: {str(e)}") + logger.error(f"Error during transcription streaming: {e!s}") raise finally: # Clean up temporary file @@ -201,15 +197,15 @@ async def generate_transcription_stream_from_data( os.unlink(temp_file_path) logger.debug(f"Cleaned up temporary file: {temp_file_path}") except Exception as e: - logger.warning(f"Failed to clean up temporary file {temp_file_path}: {str(e)}") + logger.warning(f"Failed to clean up temporary file {temp_file_path}: {e!s}") async def _save_uploaded_file(self, file) -> str: """ Save the uploaded file to a temporary location. - + Args: file: The uploaded file object. - + Returns: str: Path to the temporary file. """ @@ -218,39 +214,35 @@ async def _save_uploaded_file(self, file) -> str: file_extension = os.path.splitext(file.filename)[1] if file.filename else ".wav" print("file_extension", file_extension) - + # Read file content first (this can only be done once with FastAPI uploads) content = await file.read() - + # Create temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as temp_file: # Write the file contents temp_file.write(content) temp_path = temp_file.name - + logger.debug(f"Saved uploaded file to temporary location: {temp_path}") return temp_path - + except Exception as e: - logger.error(f"Error saving uploaded file: {str(e)}") + logger.error(f"Error saving uploaded file: {e!s}") raise - async def _prepare_transcription_request( - self, - request: TranscriptionRequest - ) -> Dict[str, Any]: + async def _prepare_transcription_request(self, request: TranscriptionRequest) -> dict[str, Any]: """ Prepare a transcription request by parsing model parameters. - + Args: request: TranscriptionRequest object. audio_path: Path to the audio file. - + Returns: Dict containing the request data ready for the model. """ try: - file = request.file file_path = await self._save_uploaded_file(file) @@ -258,41 +250,37 @@ async def _prepare_transcription_request( "audio_path": file_path, "verbose": False, } - + # Add optional parameters if provided if request.temperature is not None: request_data["temperature"] = request.temperature - + if request.language is not None: request_data["language"] = request.language - + if request.prompt is not None: request_data["initial_prompt"] = request.prompt - + # Map additional parameters if they exist decode_options = {} if request.language is not None: decode_options["language"] = request.language - + # Add decode options to request data request_data.update(decode_options) - + logger.debug(f"Prepared transcription request: {request_data}") - + return request_data - + except Exception as e: - logger.error(f"Failed to prepare transcription request: {str(e)}") + logger.error(f"Failed to prepare transcription request: {e!s}") content = create_error_response( - f"Failed to process request: {str(e)}", - "bad_request", - HTTPStatus.BAD_REQUEST + f"Failed to process request: {e!s}", "bad_request", HTTPStatus.BAD_REQUEST ) raise HTTPException(status_code=400, detail=content) - async def transcribe_from_data( - self, request_data: Dict[str, Any] - ) -> TranscriptionResponse: + async def transcribe_from_data(self, request_data: dict[str, Any]) -> TranscriptionResponse: """Run transcription from pre-processed request data. This method is used by ``HandlerProcessProxy`` for IPC: the @@ -330,13 +318,11 @@ async def transcribe_from_data( try: os.unlink(temp_file_path) except Exception as e: - logger.warning( - f"Failed to clean up temp file {temp_file_path}: {e}" - ) + logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}") async def transcribe_stream_from_data( self, - request_data: Dict[str, Any], + request_data: dict[str, Any], ) -> AsyncGenerator[str, None]: """Run streaming transcription from pre-processed request data. @@ -408,11 +394,9 @@ async def transcribe_stream_from_data( try: os.unlink(temp_file_path) except Exception as e: - logger.warning( - f"Failed to clean up temp file {temp_file_path}: {e}" - ) + logger.warning(f"Failed to clean up temp file {temp_file_path}: {e}") - async def get_queue_stats(self) -> Dict[str, Any]: + async def get_queue_stats(self) -> dict[str, Any]: """Get statistics from the inference worker. Returns @@ -432,12 +416,11 @@ async def cleanup(self) -> None: """ try: logger.info("Cleaning up MLXWhisperHandler resources") - if hasattr(self, 'inference_worker'): + if hasattr(self, "inference_worker"): self.inference_worker.stop() # Force garbage collection gc.collect() logger.info("MLXWhisperHandler cleanup completed successfully") except Exception as e: - logger.error(f"Error during MLXWhisperHandler cleanup: {str(e)}") + logger.error(f"Error during MLXWhisperHandler cleanup: {e!s}") raise - diff --git a/app/message_converters/abstract_converter.py b/app/message_converters/abstract_converter.py index 542d1acf..29994a0f 100644 --- a/app/message_converters/abstract_converter.py +++ b/app/message_converters/abstract_converter.py @@ -2,9 +2,10 @@ from typing import Any + class AbstractMessageConverter: """Abstract message converter class that should not be used directly. - + Provided properties and methods should be used in derived classes to convert messages to be compatible with specific model chat templates. """ @@ -13,4 +14,4 @@ def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any """Convert messages to be compatible with specific model chat templates""" raise NotImplementedError( "AbstractMessageConverter.convert_messages has not been implemented!" - ) \ No newline at end of file + ) diff --git a/app/message_converters/glm4_moe.py b/app/message_converters/glm4_moe.py index 4c1b5b27..b1052df6 100644 --- a/app/message_converters/glm4_moe.py +++ b/app/message_converters/glm4_moe.py @@ -5,17 +5,18 @@ from .abstract_converter import AbstractMessageConverter + class GLM4MoEMessageConverter(AbstractMessageConverter): """GLM4 MoE-specific message format converter""" def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert message format to be compatible with GLM4 MoE chat templates. - + Parameters ---------- messages : list[dict[str, Any]] List of messages in OpenAI API format. - + Returns ------- list[dict[str, Any]] @@ -32,12 +33,12 @@ def convert_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any def _convert_single_message(self, message: dict[str, Any]) -> dict[str, Any]: """Convert a single message. - + Parameters ---------- message : dict[str, Any] Single message to convert. - + Returns ------- dict[str, Any] @@ -53,7 +54,7 @@ def _convert_single_message(self, message: dict[str, Any]) -> dict[str, Any]: def _convert_tool_calls(self, tool_calls: list[dict[str, Any]]) -> None: """Convert arguments format in tool calls. - + Parameters ---------- tool_calls : list[dict[str, Any]] @@ -69,12 +70,12 @@ def _convert_tool_calls(self, tool_calls: list[dict[str, Any]]) -> None: def _parse_arguments_string(self, arguments_str: str) -> Any: """Parse GLM4 MoE-specific argument string format. - + Parameters ---------- arguments_str : str Arguments in string format. - + Returns ------- Any @@ -83,4 +84,4 @@ def _parse_arguments_string(self, arguments_str: str) -> Any: try: return json.loads(arguments_str) except json.JSONDecodeError: - return arguments_str \ No newline at end of file + return arguments_str diff --git a/app/models/mflux.py b/app/models/mflux.py index 68c7b96d..30a3e74d 100644 --- a/app/models/mflux.py +++ b/app/models/mflux.py @@ -1,21 +1,21 @@ from __future__ import annotations -from loguru import logger -import inspect -from PIL import Image from abc import ABC, abstractmethod -from typing import Any, Callable, Optional +from collections.abc import Callable +import inspect +from typing import Any +from loguru import logger from mflux.models.common.config import ModelConfig -from mflux.models.z_image.variants import ZImageTurbo from mflux.models.fibo.variants.txt2img.fibo import FIBO -from mflux.models.flux.variants.txt2img.flux import Flux1 -from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage from mflux.models.flux.variants.kontext.flux_kontext import Flux1Kontext -from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit -from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein +from mflux.models.flux.variants.txt2img.flux import Flux1 from mflux.models.flux2.variants.edit.flux2_klein_edit import Flux2KleinEdit - +from mflux.models.flux2.variants.txt2img.flux2_klein import Flux2Klein +from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit +from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage +from mflux.models.z_image.variants import ZImageTurbo +from PIL import Image # ----------------------------------------------------------------------------- # Exceptions @@ -24,22 +24,18 @@ class ImageModelError(Exception): """Base exception for image generation model errors.""" - pass class ModelLoadError(ImageModelError): """Raised when model loading fails.""" - pass class ModelGenerationError(ImageModelError): """Raised when image generation fails.""" - pass class InvalidConfigurationError(ImageModelError): """Raised when configuration is invalid.""" - pass # ----------------------------------------------------------------------------- @@ -48,8 +44,8 @@ class InvalidConfigurationError(ImageModelError): def _lora_validate( - lora_paths: Optional[list[str]] | None, - lora_scales: Optional[list[float]] | None, + lora_paths: list[str] | None, + lora_scales: list[float] | None, ) -> None: if (lora_paths is None) != (lora_scales is None): raise InvalidConfigurationError( @@ -69,9 +65,9 @@ def __init__( self, model_type: str, model_config: ModelConfig, - quantize: Optional[int] = None, - lora_paths: Optional[list[str]] = None, - lora_scales: Optional[list[float]] = None, + quantize: int | None = None, + lora_paths: list[str] | None = None, + lora_scales: list[float] | None = None, ) -> None: if quantize is not None and quantize not in (4, 8, 16): raise InvalidConfigurationError( @@ -88,9 +84,9 @@ def __init__( def from_name( cls, config_name: str, - quantize: Optional[int] = None, - lora_paths: Optional[list[str]] = None, - lora_scales: Optional[list[float]] = None, + quantize: int | None = None, + lora_paths: list[str] | None = None, + lora_scales: list[float] | None = None, ) -> ModelConfiguration: if config_name not in _CONFIG_REGISTRY: available = ", ".join(_CONFIG_REGISTRY.keys()) @@ -147,15 +143,12 @@ def is_loaded(self) -> bool: @abstractmethod def _load_model(self) -> None: """Load the specific model implementation.""" - pass def _generate_image(self, prompt: str, seed: int = 42, **kwargs: Any) -> Image.Image: sig = inspect.signature(self._model.generate_image) valid = set(sig.parameters.keys()) filtered = {k: v for k, v in kwargs.items() if k in valid} - result = self._model.generate_image( - prompt=prompt, seed=seed, **filtered - ) + result = self._model.generate_image(prompt=prompt, seed=seed, **filtered) return result.image def __call__(self, prompt: str, seed: int = 42, **kwargs: Any) -> Image.Image: @@ -262,9 +255,9 @@ def __init__( self, model_path: str, config_name: str, - quantize: Optional[int] = None, - lora_paths: Optional[list[str]] = None, - lora_scales: Optional[list[float]] = None, + quantize: int | None = None, + lora_paths: list[str] | None = None, + lora_scales: list[float] | None = None, ) -> None: self.model_path = model_path self.config_name = config_name @@ -332,4 +325,4 @@ def is_loaded(self) -> bool: guidance=1.0, image_paths=[image_path], ) - image.save("examples/result.png") \ No newline at end of file + image.save("examples/result.png") diff --git a/app/models/mlx_embeddings.py b/app/models/mlx_embeddings.py index 84454863..70620bd9 100644 --- a/app/models/mlx_embeddings.py +++ b/app/models/mlx_embeddings.py @@ -1,12 +1,13 @@ import gc + import mlx.core as mx from mlx_embeddings.utils import load -from typing import List, Optional + class MLX_Embeddings: """ A wrapper class for MLX Embeddings that handles memory management to prevent leaks. - + This class provides a unified interface for generating embeddings from text inputs, with proper cleanup of MLX arrays and memory management. """ @@ -14,26 +15,26 @@ class MLX_Embeddings: def __init__(self, model_path: str): """ Initialize the MLX_Embeddings model. - + Args: model_name (str): Name of the model to load. - + Raises: ValueError: If model loading fails. """ try: self.model, self.tokenizer = load(model_path) except Exception as e: - raise ValueError(f"Error loading model: {str(e)}") - - def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array: + raise ValueError(f"Error loading model: {e!s}") + + def _get_embeddings(self, texts: list[str], max_length: int = 512) -> mx.array: """ Get embeddings for a list of texts with proper memory management. - + Args: texts: List of text inputs max_length: Maximum sequence length for tokenization - + Returns: MLX array of embeddings """ @@ -42,25 +43,20 @@ def _get_embeddings(self, texts: List[str], max_length: int = 512) -> mx.array: try: # Tokenize inputs tokenizer = getattr(self.tokenizer, "_tokenizer", self.tokenizer) - + inputs = tokenizer( - texts, - return_tensors="np", - padding=True, - truncation=True, - max_length=max_length + texts, return_tensors="np", padding=True, truncation=True, max_length=max_length ) - + # Generate embeddings outputs = self.model( - mx.array(inputs["input_ids"]), - attention_mask=mx.array(inputs["attention_mask"]) + mx.array(inputs["input_ids"]), attention_mask=mx.array(inputs["attention_mask"]) ).text_embeds - + # Return a copy to ensure the result persists after cleanup return mx.array(outputs) - - except Exception as e: + + except Exception: # Clean up on error self._cleanup_arrays(inputs, outputs) raise @@ -75,25 +71,25 @@ def _cleanup_arrays(self, *arrays): try: if isinstance(array, dict): for key, value in array.items(): - if hasattr(value, 'nbytes'): + if hasattr(value, "nbytes"): del value - elif hasattr(array, 'nbytes'): + elif hasattr(array, "nbytes"): del array except: pass - + # Clear MLX cache and force garbage collection mx.clear_cache() gc.collect() - def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]]: + def __call__(self, texts: list[str], max_length: int = 512) -> list[list[float]]: """ Generate embeddings for a list of texts. - + Args: texts: List of text inputs max_length: Maximum sequence length for tokenization - + Returns: List of embedding vectors as float lists """ @@ -106,7 +102,7 @@ def __call__(self, texts: List[str], max_length: int = 512) -> List[List[float]] mx.clear_cache() gc.collect() return result - except Exception as e: + except Exception: # Clean up on error mx.clear_cache() gc.collect() @@ -116,15 +112,15 @@ def cleanup(self): """Explicitly cleanup resources.""" try: # Clear any cached model outputs - if hasattr(self, 'model'): + if hasattr(self, "model"): del self.model - if hasattr(self, 'tokenizer'): + if hasattr(self, "tokenizer"): del self.tokenizer - + # Clear MLX cache and force garbage collection mx.clear_cache() gc.collect() - except Exception as e: + except Exception: # Log cleanup errors but don't raise pass @@ -132,6 +128,7 @@ def __del__(self): """Destructor to ensure cleanup on object deletion.""" self.cleanup() + if __name__ == "__main__": model_path = "mlx-community/all-MiniLM-L6-v2-4bit" model = MLX_Embeddings(model_path) @@ -141,4 +138,4 @@ def __del__(self): print(f"Generated embeddings shape: {len(embeddings)} x {len(embeddings[0])}") finally: # Explicit cleanup - model.cleanup() \ No newline at end of file + model.cleanup() diff --git a/app/models/mlx_lm.py b/app/models/mlx_lm.py index 82d3595d..063231f1 100644 --- a/app/models/mlx_lm.py +++ b/app/models/mlx_lm.py @@ -1,18 +1,18 @@ +from collections.abc import Generator +from dataclasses import dataclass import os -import mlx.core as mx +from typing import Any + from loguru import logger +import mlx.core as mx +from mlx_lm.generate import GenerationResponse, stream_generate +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.utils import load -from mlx_lm.generate import ( - stream_generate -) -from dataclasses import dataclass -from mlx_lm.generate import GenerationResponse from outlines.processors import JSONLogitsProcessor -from mlx_lm.models.cache import make_prompt_cache -from mlx_lm.sample_utils import make_sampler, make_logits_processors -from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer + from ..utils.debug_logging import log_debug_chat_template -from typing import List, Dict, Union, Generator, Any +from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer DEFAULT_TEMPERATURE = float(os.getenv("DEFAULT_TEMPERATURE", "0.7")) DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95")) @@ -24,6 +24,7 @@ DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1000000")) DEFAULT_REPETITION_CONTEXT_SIZE = int(os.getenv("DEFAULT_REPETITION_CONTEXT_SIZE", "20")) + @dataclass class CompletionResponse: """ @@ -40,24 +41,36 @@ class CompletionResponse: """ text: str = None - tokens: List[int] = None + tokens: list[int] = None peak_memory: float = None generation_tps: float = None prompt_tps: float = None prompt_tokens: int = None generation_tokens: int = None + class MLX_LM: """ A wrapper class for MLX Language Model that handles both streaming and non-streaming inference. - + This class provides a unified interface for generating text responses from text prompts, supporting both streaming and non-streaming modes. """ - def __init__(self, model_path: str, draft_model_path: str = None, num_draft_tokens: int = 2, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None, debug: bool = False): + def __init__( + self, + model_path: str, + draft_model_path: str = None, + num_draft_tokens: int = 2, + context_length: int | None = None, + trust_remote_code: bool = False, + chat_template_file: str = None, + debug: bool = False, + ): try: - self.model, self.tokenizer = load(model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code}) + self.model, self.tokenizer = load( + model_path, lazy=False, tokenizer_config={"trust_remote_code": trust_remote_code} + ) self.context_length = context_length self.draft_model = None self.draft_tokenizer = None @@ -72,21 +85,29 @@ def __init__(self, model_path: str, draft_model_path: str = None, num_draft_toke if chat_template_file: if not os.path.exists(chat_template_file): raise ValueError(f"Chat template file {chat_template_file} does not exist") - with open(chat_template_file, "r") as f: + with open(chat_template_file) as f: template_content = f.read() self.tokenizer.chat_template = template_content if self.debug: - log_debug_chat_template(chat_template_file=chat_template_file, template_content=template_content) + log_debug_chat_template( + chat_template_file=chat_template_file, template_content=template_content + ) except Exception as e: - raise ValueError(f"Error loading model: {str(e)}") + raise ValueError(f"Error loading model: {e!s}") def _load_draft_model(self, draft_model_path: str, trust_remote_code: bool) -> None: try: - self.draft_model, self.draft_tokenizer = load(draft_model_path, lazy=False, tokenizer_config = {"trust_remote_code": trust_remote_code}) - self.context_length = None # speculative decoding does not support context length, should be set to None + self.draft_model, self.draft_tokenizer = load( + draft_model_path, + lazy=False, + tokenizer_config={"trust_remote_code": trust_remote_code}, + ) + self.context_length = ( + None # speculative decoding does not support context length, should be set to None + ) self._validate_draft_tokenizer() except Exception as e: - raise ValueError(f"Error loading draft model: {str(e)}") + raise ValueError(f"Error loading draft model: {e!s}") def _validate_draft_tokenizer(self) -> None: if self.draft_tokenizer.vocab_size != self.tokenizer.vocab_size: @@ -95,7 +116,7 @@ def _validate_draft_tokenizer(self) -> None: "Speculative decoding may not work as expected." ) - def create_prompt_cache(self) -> List[Any]: + def create_prompt_cache(self) -> list[Any]: cache = make_prompt_cache(self.model, max_kv_size=self.context_length) if self.draft_model: cache += make_prompt_cache(self.draft_model, max_kv_size=self.context_length) @@ -104,7 +125,9 @@ def create_prompt_cache(self) -> List[Any]: def get_model_type(self) -> str: return self.model_type - def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwargs: Dict[str, Any]) -> str: + def create_input_prompt( + self, messages: list[dict[str, str]], chat_template_kwargs: dict[str, Any] + ) -> str: use_partial = chat_template_kwargs.pop("_partial_mode", False) return self.tokenizer.apply_chat_template( @@ -115,16 +138,12 @@ def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwar **chat_template_kwargs, ) - def encode_prompt(self, input_prompt: str) -> List[int]: + def encode_prompt(self, input_prompt: str) -> list[int]: return self.tokenizer.encode(input_prompt) def __call__( - self, - input_ids: List[int], - prompt_cache: List[Any] = None, - stream: bool = False, - **kwargs - ) -> Union[CompletionResponse, Generator[GenerationResponse, None, None]]: + self, input_ids: list[int], prompt_cache: list[Any] = None, stream: bool = False, **kwargs + ) -> CompletionResponse | Generator[GenerationResponse, None, None]: """ Generate text response from the model. @@ -137,6 +156,7 @@ def __call__( - seed: Random seed (default: 0) - max_tokens: Maximum number of tokens to generate (default: 256) """ + # Set default parameters if not provided (use 'is not None' to preserve valid 0 values) def _get(key, default): v = kwargs.get(key) @@ -172,16 +192,14 @@ def _get(key, default): logits_processors = make_logits_processors( logit_bias=logit_bias, repetition_penalty=repetition_penalty, - repetition_context_size=repetition_context_size + repetition_context_size=repetition_context_size, ) json_schema = kwargs.get("schema") if json_schema: logits_processors.append( JSONLogitsProcessor( - schema=json_schema, - tokenizer=self.outlines_tokenizer, - tensor_library_name="mlx" + schema=json_schema, tokenizer=self.outlines_tokenizer, tensor_library_name="mlx" ) ) @@ -189,12 +207,10 @@ def _get(key, default): # None or negative values (e.g., -1) result in non-deterministic generation if seed and seed >= 0: mx.random.seed(seed) - + prompt_progress_callback = kwargs.get("prompt_progress_callback") - - sampler = make_sampler( - **sampler_kwargs - ) + + sampler = make_sampler(**sampler_kwargs) stream_response = stream_generate( self.model, @@ -206,7 +222,7 @@ def _get(key, default): num_draft_tokens=self.num_draft_tokens, prompt_cache=prompt_cache, logits_processors=logits_processors, - prompt_progress_callback=prompt_progress_callback + prompt_progress_callback=prompt_progress_callback, ) if stream: return stream_response @@ -219,7 +235,7 @@ def _get(key, default): tokens.append(chunk.token) if chunk.finish_reason: final_chunk = chunk - + return CompletionResponse( text=text, tokens=tokens, @@ -228,4 +244,4 @@ def _get(key, default): prompt_tps=final_chunk.prompt_tps, prompt_tokens=final_chunk.prompt_tokens, generation_tokens=final_chunk.generation_tokens, - ) \ No newline at end of file + ) diff --git a/app/models/mlx_vlm.py b/app/models/mlx_vlm.py index 39dfbfdd..968dae29 100644 --- a/app/models/mlx_vlm.py +++ b/app/models/mlx_vlm.py @@ -1,14 +1,17 @@ +from collections.abc import Generator +from dataclasses import dataclass import os -import torch +from typing import Any + import mlx.core as mx -from dataclasses import dataclass -from typing import List, Dict, Union, Generator, Any -from mlx_vlm.models.cache import make_prompt_cache from mlx_vlm import load, stream_generate +from mlx_vlm.models.cache import make_prompt_cache from mlx_vlm.video_generate import process_vision_info -from ..utils.prompt_cache import LRUPromptCache from outlines.processors import JSONLogitsProcessor +import torch + from ..utils.outlines_transformer_tokenizer import OutlinesTransformerTokenizer +from ..utils.prompt_cache import LRUPromptCache # Default model parameters DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "100000")) @@ -17,6 +20,7 @@ DEFAULT_SEED = int(os.getenv("DEFAULT_SEED", "0")) DEFAULT_REPETITION_CONTEXT_SIZE = int(os.getenv("DEFAULT_REPETITION_CONTEXT_SIZE", "20")) + @dataclass class CompletionResponse: """ @@ -33,97 +37,99 @@ class CompletionResponse: """ text: str = None - tokens: List[int] = None + tokens: list[int] = None peak_memory: float = None generation_tps: float = None prompt_tps: float = None prompt_tokens: int = None generation_tokens: int = None + class MLX_VLM: """ A wrapper class for MLX Multimodal Model that handles both streaming and non-streaming inference. - + This class provides a unified interface for generating text responses from images and text prompts, supporting both streaming and non-streaming modes. """ - - def __init__(self, model_path: str, context_length: int | None = None, trust_remote_code: bool = False, chat_template_file: str = None): + + def __init__( + self, + model_path: str, + context_length: int | None = None, + trust_remote_code: bool = False, + chat_template_file: str = None, + ): """ Initialize the MLX_VLM model. - + Args: model_path (str): Path to the model directory containing model weights and configuration. context_length (int | None): Maximum context length for the model. If None, uses model default. trust_remote_code (bool): Enable trust_remote_code when loading models. Defaults to False. + Raises: ValueError: If model loading fails. """ try: - self.model, self.processor = load(model_path, lazy=False, trust_remote_code=trust_remote_code) + self.model, self.processor = load( + model_path, lazy=False, trust_remote_code=trust_remote_code + ) self.config = self.model.config self.context_length = context_length self.outlines_tokenizer = OutlinesTransformerTokenizer(self.processor.tokenizer) if chat_template_file: if not os.path.exists(chat_template_file): raise ValueError(f"Chat template file {chat_template_file} does not exist") - with open(chat_template_file, "r") as f: + with open(chat_template_file) as f: self.processor.chat_template = f.read() except Exception as e: - raise ValueError(f"Error loading model: {str(e)}") + raise ValueError(f"Error loading model: {e!s}") def _is_video_model(self): - return hasattr(self.config, "video_token_id") or hasattr( - self.config, "video_token_index" - ) + return hasattr(self.config, "video_token_id") or hasattr(self.config, "video_token_index") def get_model_type(self): return self.config.model_type - def create_prompt_cache(self) -> List[Any]: + def create_prompt_cache(self) -> list[Any]: return make_prompt_cache(self.model.language_model, max_kv_size=self.context_length) - def create_input_prompt(self, messages: List[Dict[str, str]], chat_template_kwargs: Dict[str, Any]) -> str: + def create_input_prompt( + self, messages: list[dict[str, str]], chat_template_kwargs: dict[str, Any] + ) -> str: chat_template_kwargs.pop("_partial_mode", None) return self.processor.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=True, - **chat_template_kwargs + messages, tokenize=False, add_generation_prompt=True, **chat_template_kwargs ) - def create_inputs(self, text: str, images: List[str] = None, videos: List[str] = None) -> Dict[str, Any]: + def create_inputs( + self, text: str, images: list[str] = None, videos: list[str] = None + ) -> dict[str, Any]: return self.processor( - text=[text], - images=images, - videos=videos, - padding=True, - return_tensors="pt" + text=[text], images=images, videos=videos, padding=True, return_tensors="pt" ) def __call__( - self, - prompt: str, - prompt_cache: List[Any] = None, - stream: bool = False, - **kwargs - ) -> Union[CompletionResponse, Generator[str, None, None]]: + self, prompt: str, prompt_cache: list[Any] = None, stream: bool = False, **kwargs + ) -> CompletionResponse | Generator[str, None, None]: """ Generate text response from images and messages. - + Args: prompt (str): The input prompt text. prompt_cache (List[Any], optional): Prompt cache for faster inference. stream (bool, optional): Whether to stream the response. Defaults to False. **kwargs: Additional model parameters (temperature, max_tokens, etc.) - schema (dict, optional): JSON schema for structured output generation. - + Returns: - Union[CompletionResponse, Generator[str, None, None]]: + Union[CompletionResponse, Generator[str, None, None]]: - If stream=False: Complete response as CompletionResponse - If stream=True: Generator yielding response chunks """ + def _get(key, default): v = kwargs.get(key) return default if v is None else v @@ -139,9 +145,7 @@ def _get(key, default): if json_schema: logits_processors.append( JSONLogitsProcessor( - schema=json_schema, - tokenizer=self.outlines_tokenizer, - tensor_library_name="mlx" + schema=json_schema, tokenizer=self.outlines_tokenizer, tensor_library_name="mlx" ) ) @@ -153,10 +157,12 @@ def _get(key, default): "max_tokens": max_tokens, "temperature": _get("temperature", DEFAULT_TEMPERATURE), "repetition_penalty": kwargs.get("repetition_penalty"), - "repetition_context_size": _get("repetition_context_size", DEFAULT_REPETITION_CONTEXT_SIZE), + "repetition_context_size": _get( + "repetition_context_size", DEFAULT_REPETITION_CONTEXT_SIZE + ), "top_p": _get("top_p", DEFAULT_TOP_P), } - + model_params.update(sampling_params) response_generator = stream_generate( @@ -164,8 +170,8 @@ def _get(key, default): self.processor, prompt=prompt, prompt_cache=prompt_cache, - logits_processors = logits_processors, - **model_params + logits_processors=logits_processors, + **model_params, ) if stream: @@ -199,19 +205,21 @@ def _get(key, default): model = MLX_VLM(model_path, context_length=2048) - tools = [{ - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the weather for a given city", - "parameters": { - "type": "object", - "properties": { - "city": {"type": "string", "description": "The city to get the weather for"} - } + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a given city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city to get the weather for"} + }, + }, + "required": ["city"], }, - "required": ["city"] - }} + } ] chat_template_kwargs = { "tools": tools, @@ -221,15 +229,9 @@ def _get(key, default): { "role": "user", "content": [ - { - "type": "text", - "text": "What is the weather in New York?" - }, - { - "type": "image", - "image": image_path - } - ] + {"type": "text", "text": "What is the weather in New York?"}, + {"type": "image", "image": image_path}, + ], } ] prompt_cache = LRUPromptCache() @@ -252,11 +254,11 @@ def _get(key, default): "repetition_penalty": None, "repetition_context_size": 20, } - + inputs.update(sampling_params) inputs["schema"] = None response = model(input_prompt, stream=False, **inputs) - - print("RESPONSE: ", response) \ No newline at end of file + + print("RESPONSE: ", response) diff --git a/app/models/mlx_whisper.py b/app/models/mlx_whisper.py index 61677337..f4afd89c 100644 --- a/app/models/mlx_whisper.py +++ b/app/models/mlx_whisper.py @@ -1,7 +1,8 @@ -import librosa -import numpy as np from functools import lru_cache + +import librosa from mlx_whisper.transcribe import transcribe +import numpy as np SAMPLING_RATE = 16000 CHUNK_SIZE = 30 @@ -13,12 +14,14 @@ def load_audio(fname): a, _ = librosa.load(fname, sr=SAMPLING_RATE, dtype=np.float32) return a + @lru_cache(maxsize=32) def calculate_audio_duration(audio_path: str) -> int: """Calculate the duration of the audio file in seconds.""" audio = load_audio(audio_path) return len(audio) / SAMPLING_RATE + class MLX_Whisper: def __init__(self, model_path: str): self.model_path = model_path @@ -28,32 +31,32 @@ def _transcribe_generator(self, audio_path: str, **kwargs): # Load the audio file audio = load_audio(audio_path) duration = calculate_audio_duration(audio_path) - + beg = 0.0 while beg < duration: # Calculate chunk boundaries chunk_end = min(beg + CHUNK_SIZE, duration) - + # Extract audio chunk beg_samples = int(beg * SAMPLING_RATE) end_samples = int(chunk_end * SAMPLING_RATE) audio_chunk = audio[beg_samples:end_samples] - + # Transcribe chunk result = transcribe(audio_chunk, path_or_hf_repo=self.model_path, **kwargs) - + # Add timing information result["chunk_start"] = beg result["chunk_end"] = chunk_end - + yield result - + beg += CHUNK_SIZE def __call__(self, audio_path: str, stream: bool = False, **kwargs): """ Transcribe audio file. - + Args: audio_path: Path to audio file stream: If True, yields chunks. If False, transcribes entire file at once. @@ -61,13 +64,12 @@ def __call__(self, audio_path: str, stream: bool = False, **kwargs): """ if stream: return self._transcribe_generator(audio_path, **kwargs) - else: - return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs) - - + return transcribe(audio_path, path_or_hf_repo=self.model_path, **kwargs) + + if __name__ == "__main__": model = MLX_Whisper("mlx-community/whisper-tiny") # Non-streaming (fastest for most use cases) result = model("examples/audios/podcast.wav", stream=True) for chunk in result: - print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}") \ No newline at end of file + print(f"[{chunk['chunk_start']:.1f}s - {chunk['chunk_end']:.1f}s]: {chunk['text']}") diff --git a/app/parsers/abstract_parser.py b/app/parsers/abstract_parser.py index aecdf927..a5331626 100644 --- a/app/parsers/abstract_parser.py +++ b/app/parsers/abstract_parser.py @@ -248,14 +248,13 @@ def extract_tool_calls_streaming(self, chunk: str) -> tuple[dict[str, list] | st - extracted_content: Tool calls dict, passthrough chunk, or None - is_complete: True if chunk should be sent, False if buffering """ + def _merge_content_payload( payload: dict[str, list] | None, leading_content: str = "", trailing_content: str = "", ) -> dict[str, list | str]: - merged: dict[str, list | str] = ( - dict(payload) if isinstance(payload, dict) else {} - ) + merged: dict[str, list | str] = dict(payload) if isinstance(payload, dict) else {} pieces: list[str] = [] if leading_content: pieces.append(leading_content) diff --git a/app/parsers/function_parameter.py b/app/parsers/function_parameter.py index d56d28f3..3a1538bf 100644 --- a/app/parsers/function_parameter.py +++ b/app/parsers/function_parameter.py @@ -64,7 +64,9 @@ def __init__( ) @staticmethod - def _coerce_parameter_value(param_value: str) -> str | int | float | bool | list[Any] | dict[str, Any]: + def _coerce_parameter_value( + param_value: str, + ) -> str | int | float | bool | list[Any] | dict[str, Any]: """Parse tool argument values as JSON when possible.""" try: loaded = json.loads(param_value) @@ -121,7 +123,9 @@ def _extract_tool_calls_permissive(self, model_output: str) -> list[dict[str, st function_content = block[function_content_start:function_close_idx] arguments: dict[str, str | int | float | bool | list[Any] | dict[str, Any]] = {} - for param_name_raw, param_value_raw in self.permissive_parameter_regex.findall(function_content): + for param_name_raw, param_value_raw in self.permissive_parameter_regex.findall( + function_content + ): param_name = param_name_raw.strip() param_value = param_value_raw.strip() arguments[param_name] = self._coerce_parameter_value(param_value) diff --git a/app/parsers/functiongemma.py b/app/parsers/functiongemma.py index 97df74eb..d0bcf779 100644 --- a/app/parsers/functiongemma.py +++ b/app/parsers/functiongemma.py @@ -1,7 +1,7 @@ from __future__ import annotations -import re import json +import re from .abstract_parser import AbstractToolParser @@ -32,12 +32,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: """Extract tool calls from complete model output. - + Parameters ---------- model_output : str Complete model output containing tool calls. - + Returns ------- dict[str, list] | None @@ -46,9 +46,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: """ matches = self.tool_call_regex.findall(model_output) if not matches: - return { - "content": model_output - } + return {"content": model_output} tool_calls = [] for match in matches: function_name = match[0] @@ -59,4 +57,3 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: return { "tool_calls": tool_calls, } - diff --git a/app/parsers/glm4_moe.py b/app/parsers/glm4_moe.py index 1ef433be..1fa0b54e 100644 --- a/app/parsers/glm4_moe.py +++ b/app/parsers/glm4_moe.py @@ -1,7 +1,7 @@ from __future__ import annotations -import re import json +import re from .abstract_parser import AbstractToolParser from .hermes import HermesReasoningParser @@ -19,13 +19,15 @@ class GLM4MoEReasoningParser(HermesReasoningParser): reasoning_content """ - def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None: + def __init__( + self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE + ) -> None: """Initialize the Hermes4 reasoning parser with appropriate regex patterns.""" super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close) - + def respects_enable_thinking(self) -> bool: """Check if the reasoning parser respects the enable_thinking flag. - + Returns ------- bool @@ -52,7 +54,7 @@ class GLM4MoEToolParser(AbstractToolParser): def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None: """Initialize the GLM4 MoE tool parser with appropriate regex patterns.""" super().__init__(tool_open=tool_open, tool_close=tool_close) - + self.func_call_regex = re.compile(r".*?", re.DOTALL) self.func_detail_regex = re.compile( r"(.*?)(.*?)?", re.DOTALL @@ -64,12 +66,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: """Extract tool calls from complete model output. - + Parameters ---------- model_output : str Complete model output containing tool calls in JSON format. - + Returns ------- dict[str, list] | None @@ -78,9 +80,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: """ matches = self.func_call_regex.findall(model_output) if not matches: - return { - "content": model_output - } + return {"content": model_output} tool_calls = [] for match in matches: tc_detail = self.func_detail_regex.search(match) @@ -92,10 +92,7 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: arg_key = key.strip() arg_value = value.strip() arg_dct[arg_key] = arg_value - tool_calls.append({ - "name": tc_name.strip(), - "arguments": json.dumps(arg_dct, ensure_ascii=False) - }) - return { - "tool_calls": tool_calls - } + tool_calls.append( + {"name": tc_name.strip(), "arguments": json.dumps(arg_dct, ensure_ascii=False)} + ) + return {"tool_calls": tool_calls} diff --git a/app/parsers/harmony.py b/app/parsers/harmony.py index 838600c5..ad07b71e 100644 --- a/app/parsers/harmony.py +++ b/app/parsers/harmony.py @@ -1,26 +1,26 @@ from __future__ import annotations from enum import Enum -from openai_harmony import ( - load_harmony_encoding, - HarmonyEncodingName, - StreamableParser, - Role -) + +from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding + class ChannelType(Enum): """Enumeration of harmony channel types.""" ANALYSIS = "analysis" - COMMENTARY = "commentary" + COMMENTARY = "commentary" FINAL = "final" + class ToolParserState(Enum): """Enumeration of parser states.""" + NORMAL = "normal" FOUND_ARGUMENTS = "found_arguments" END_STREAM = "end_stream" + class HarmonyParser: """Parser for Harmony encoding.""" @@ -39,7 +39,7 @@ def parse(self, text: str) -> dict[str, str] | None: """ if self.end_tool_chunk in text: end_tool_index = text.find(self.end_tool_chunk) - text = text[:end_tool_index + len(self.end_tool_chunk)] + text = text[: end_tool_index + len(self.end_tool_chunk)] result = { "content": None, @@ -47,24 +47,28 @@ def parse(self, text: str) -> dict[str, str] | None: "reasoning_content": None, } tokens = self.encoding.encode(text, allowed_special="all") - parsed_messages = self.encoding.parse_messages_from_completion_tokens(tokens, role=Role.ASSISTANT) + parsed_messages = self.encoding.parse_messages_from_completion_tokens( + tokens, role=Role.ASSISTANT + ) for message in parsed_messages: if message.channel == ChannelType.ANALYSIS.value: result["reasoning_content"] = message.content[0].text elif message.channel == ChannelType.COMMENTARY.value: - result["tool_calls"].append({ - "name": message.recipient.replace("functions.", ""), - "arguments": message.content[0].text - }) + result["tool_calls"].append( + { + "name": message.recipient.replace("functions.", ""), + "arguments": message.content[0].text, + } + ) elif message.channel == ChannelType.FINAL.value: result["content"] = message.content[0].text return result - + def _build_result( - self, - reasoning_contents: list[str], - tool_calls: list[dict[str, str]] | None, - contents: list[str] + self, + reasoning_contents: list[str], + tool_calls: list[dict[str, str]] | None, + contents: list[str], ) -> dict[str, str | list | None]: """Build the result dictionary from accumulated content.""" return { @@ -72,7 +76,7 @@ def _build_result( "tool_calls": tool_calls, "content": "".join(contents) or None, } - + def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | None, bool]: """Parse the chunk and return the parsed content.""" if self.state == ToolParserState.END_STREAM: @@ -85,9 +89,9 @@ def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | No # Check for end marker and truncate if self.end_tool_chunk in chunk: end_tool_index = chunk.find(self.end_tool_chunk) - chunk = chunk[:end_tool_index + len(self.end_tool_chunk)] + chunk = chunk[: end_tool_index + len(self.end_tool_chunk)] end_stream_state = True - + # Process chunk tokens chunk_tokens = self.encoding.encode(chunk, allowed_special="all") for chunk_token in chunk_tokens: @@ -115,15 +119,14 @@ def parse_streaming(self, chunk: str) -> tuple[dict[str, str | list | None] | No # Handle end of stream if end_stream_state: - tool_calls = [{ - "name": self.function_name_buffer, - "arguments": "".join(self.arguments_buffer) - }] + tool_calls = [ + {"name": self.function_name_buffer, "arguments": "".join(self.arguments_buffer)} + ] self.arguments_buffer = [] self.function_name_buffer = "" self.state = ToolParserState.END_STREAM return self._build_result(reasoning_contents, tool_calls, contents), True - + return self._build_result(reasoning_contents, None, contents), False def handle_parse_streaming_end(self) -> tuple[dict[str, str | list | None] | None, bool]: diff --git a/app/parsers/hermes.py b/app/parsers/hermes.py index a225a5ff..3452a230 100644 --- a/app/parsers/hermes.py +++ b/app/parsers/hermes.py @@ -1,7 +1,7 @@ from __future__ import annotations -import re import json +import re from .abstract_parser import ( AbstractReasoningParser, diff --git a/app/parsers/kimi_k2.py b/app/parsers/kimi_k2.py index 455933b3..bcedfe08 100644 --- a/app/parsers/kimi_k2.py +++ b/app/parsers/kimi_k2.py @@ -1,8 +1,9 @@ from __future__ import annotations -import re -import json from enum import Enum +import json +import re + from .hermes import HermesToolParser TOOL_CALL_SECTION_BEGIN = "<|tool_calls_section_begin|>" @@ -18,17 +19,20 @@ class KimiK2ToolState(Enum): NORMAL = "normal" FOUND_TOOL_SECTION = "found_tool_section" + class KimiK2ToolParser(HermesToolParser): """Kimi K2 tool parser. - + Handles tool calls in format: <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_calls_section_end|> """ - def __init__(self, tool_open: str = TOOL_CALL_SECTION_BEGIN, tool_close: str = TOOL_CALL_SECTION_END) -> None: + def __init__( + self, tool_open: str = TOOL_CALL_SECTION_BEGIN, tool_close: str = TOOL_CALL_SECTION_END + ) -> None: """Initialize Solar Open tool parser.""" super().__init__(tool_open=tool_open, tool_close=tool_close) - + self.state = KimiK2ToolState.NORMAL self.tool_call_section_regex = re.compile( re.escape(self.tool_open) + r"(.*?)" + re.escape(self.tool_close), @@ -83,4 +87,4 @@ def extract_tool_calls(self, tool_output: str) -> dict[str, list] | None: tool_calls.append({"name": name, "arguments": json.dumps(arguments)}) if not tool_calls: return None - return {"tool_calls": tool_calls} \ No newline at end of file + return {"tool_calls": tool_calls} diff --git a/app/parsers/longcat_flash_lite.py b/app/parsers/longcat_flash_lite.py index 138b2c74..857fd549 100644 --- a/app/parsers/longcat_flash_lite.py +++ b/app/parsers/longcat_flash_lite.py @@ -1,11 +1,13 @@ from __future__ import annotations import re + from .glm4_moe import GLM4MoEToolParser TOOL_OPEN = "" TOOL_CLOSE = "" + class LongCatFlashLiteToolParser(GLM4MoEToolParser): """Tool parser for LongCat Flash Lite model's tool response format. @@ -22,7 +24,7 @@ class LongCatFlashLiteToolParser(GLM4MoEToolParser): def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None: """Initialize the LongCat Flash Lite tool parser with appropriate regex patterns.""" super().__init__(tool_open=tool_open, tool_close=tool_close) - + self.func_call_regex = re.compile(r".*?", re.DOTALL) self.func_detail_regex = re.compile( r"(.*?)(.*?)?", re.DOTALL @@ -30,4 +32,4 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> self.func_arg_regex = re.compile( r"(.*?)(?:\\n|\s)*(.*?)", re.DOTALL, - ) \ No newline at end of file + ) diff --git a/app/parsers/minimax_m2.py b/app/parsers/minimax_m2.py index ca8a4cb0..5f0076b9 100644 --- a/app/parsers/minimax_m2.py +++ b/app/parsers/minimax_m2.py @@ -24,15 +24,13 @@ class MiniMaxM2ToolParser(GLM4MoEToolParser): def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> None: """Initialize the MiniMax M2 tool parser with appropriate regex patterns.""" super().__init__(tool_open=tool_open, tool_close=tool_close) - + self.func_call_regex = re.compile( r"(.*?)", re.DOTALL ) - + # Regex patterns for parsing MiniMax tool calls - self.func_detail_regex = re.compile( - r'(.*)', re.DOTALL - ) + self.func_detail_regex = re.compile(r'(.*)', re.DOTALL) self.func_arg_regex = re.compile( r'(.*?)', re.DOTALL - ) \ No newline at end of file + ) diff --git a/app/parsers/qwen3.py b/app/parsers/qwen3.py index 7a7bc224..12a252bd 100644 --- a/app/parsers/qwen3.py +++ b/app/parsers/qwen3.py @@ -13,16 +13,18 @@ class Qwen3ReasoningParser(HermesReasoningParser): reasoning_content """ - def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None: + def __init__( + self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE + ) -> None: """Initialize the Qwen3 reasoning parser with appropriate regex patterns.""" super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close) def respects_enable_thinking(self) -> bool: """Check if the reasoning parser respects the enable_thinking flag. - + Returns ------- bool True if the reasoning parser respects the enable_thinking flag, False otherwise. """ - return True \ No newline at end of file + return True diff --git a/app/parsers/qwen3_5.py b/app/parsers/qwen3_5.py index e1a7263c..3b15564f 100644 --- a/app/parsers/qwen3_5.py +++ b/app/parsers/qwen3_5.py @@ -13,16 +13,18 @@ class Qwen35ReasoningParser(Qwen3MoEReasoningParser): reasoning_content """ - def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None: + def __init__( + self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE + ) -> None: """Initialize the Qwen3.5 reasoning parser with appropriate regex patterns.""" super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close) def respects_enable_thinking(self) -> bool: """Check if the reasoning parser respects the enable_thinking flag. - + Returns ------- bool True if the reasoning parser respects the enable_thinking flag, False otherwise. """ - return True \ No newline at end of file + return True diff --git a/app/parsers/qwen3_moe.py b/app/parsers/qwen3_moe.py index a1e6da8c..ee0548fa 100644 --- a/app/parsers/qwen3_moe.py +++ b/app/parsers/qwen3_moe.py @@ -15,16 +15,18 @@ class Qwen3MoEReasoningParser(HermesReasoningParser): reasoning_content """ - def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None: + def __init__( + self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE + ) -> None: """Initialize the Qwen3 MoE reasoning parser with appropriate regex patterns.""" super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close) def needs_redacted_reasoning_prefix(self) -> bool: """Check if the reasoning parser needs a redacted reasoning prefix. - + Returns ------- bool True if the reasoning parser needs a redacted reasoning prefix, False otherwise. """ - return True \ No newline at end of file + return True diff --git a/app/parsers/solar_open.py b/app/parsers/solar_open.py index 01d5fcf6..a4463ef1 100644 --- a/app/parsers/solar_open.py +++ b/app/parsers/solar_open.py @@ -1,12 +1,12 @@ from __future__ import annotations -import json from enum import Enum +import json from loguru import logger -from .hermes import HermesReasoningParser from .abstract_parser import AbstractToolParser +from .hermes import HermesReasoningParser REASONING_OPEN = "<|think|>" REASONING_CLOSE = "<|end|>" @@ -30,18 +30,20 @@ class SolarOpenToolState(Enum): class SolarOpenReasoningParser(HermesReasoningParser): """Solar Open reasoning parser. - + Handles reasoning content in format: <|think|>reasoning_content<|end|> """ - def __init__(self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE) -> None: + def __init__( + self, reasoning_open: str = REASONING_OPEN, reasoning_close: str = REASONING_CLOSE + ) -> None: """Initialize Solar Open reasoning parser.""" super().__init__(reasoning_open=reasoning_open, reasoning_close=reasoning_close) class SolarOpenToolParser(AbstractToolParser): """Solar Open tool parser. - + Handles tool calls in format: <|tool_call:begin|><|tool_call:name|><|tool_call:args|><|tool_call:end|> """ @@ -56,12 +58,12 @@ def __init__(self, tool_open: str = TOOL_OPEN, tool_close: str = TOOL_CLOSE) -> def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: """Extract tool calls from complete model output. - + Parameters ---------- model_output : str Complete model output containing tool calls. - + Returns ------- dict[str, list] | None @@ -79,22 +81,12 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: while self.tool_open in remaining_output: tool_call_open_idx = remaining_output.find(self.tool_open) - tool_call_name_idx = remaining_output.find( - self.tool_name_prefix, tool_call_open_idx - ) - tool_call_args_idx = remaining_output.find( - self.tool_args_prefix, tool_call_name_idx - ) - tool_call_close_idx = remaining_output.find( - self.tool_close, tool_call_args_idx - ) + tool_call_name_idx = remaining_output.find(self.tool_name_prefix, tool_call_open_idx) + tool_call_args_idx = remaining_output.find(self.tool_args_prefix, tool_call_name_idx) + tool_call_close_idx = remaining_output.find(self.tool_close, tool_call_args_idx) # Validate all required tokens were found - if ( - tool_call_name_idx == -1 - or tool_call_args_idx == -1 - or tool_call_close_idx == -1 - ): + if tool_call_name_idx == -1 or tool_call_args_idx == -1 or tool_call_close_idx == -1: logger.warning( f"Malformed tool call in output, missing required tokens: {remaining_output[:100]}" ) @@ -113,31 +105,25 @@ def extract_tool_calls(self, model_output: str) -> dict[str, list] | None: json.loads(tool_args) # Validate JSON format tool_calls.append({"name": tool_name, "arguments": tool_args}) except json.JSONDecodeError as e: - logger.warning( - f"Invalid JSON in tool arguments for '{tool_name}': {e}" - ) + logger.warning(f"Invalid JSON in tool arguments for '{tool_name}': {e}") # Skip this malformed tool call and continue # Move past this tool call - remaining_output = remaining_output[ - tool_call_close_idx + len(self.tool_close) : - ] + remaining_output = remaining_output[tool_call_close_idx + len(self.tool_close) :] return { "tool_calls": tool_calls, "content": remaining_output if remaining_output else None, } - def extract_tool_calls_streaming( - self, chunk: str - ) -> tuple[dict[str, list] | str | None, bool]: + def extract_tool_calls_streaming(self, chunk: str) -> tuple[dict[str, list] | str | None, bool]: """Extract tool calls from streaming chunks. - + Parameters ---------- chunk : str Chunk of model output to process. - + Returns ------- tuple[dict[str, list] | str | None, bool] @@ -176,4 +162,4 @@ def extract_tool_calls_streaming( return None, False # Normal state - keep buffering - return None, False \ No newline at end of file + return None, False diff --git a/app/schemas/openai.py b/app/schemas/openai.py index 36a6086c..035d860d 100644 --- a/app/schemas/openai.py +++ b/app/schemas/openai.py @@ -2,18 +2,19 @@ from __future__ import annotations -import uuid -import time from enum import Enum +import time from typing import Any, ClassVar, Literal, TypeAlias +import uuid -from loguru import logger from fastapi import UploadFile +from loguru import logger from pydantic import BaseModel, ConfigDict, Field, model_validator + class OpenAIBaseModel(BaseModel): """Base model for OpenAI API schemas.""" - + # OpenAI API does allow extra fields model_config = ConfigDict(extra="allow") @@ -154,9 +155,10 @@ class PromptTokenUsageInfo(OpenAIBaseModel): cached_tokens: int | None = None + class StreamOptions(OpenAIBaseModel): """Stream options for a request.""" - + include_usage: bool | None = True continuous_usage_stats: bool | None = False @@ -178,15 +180,16 @@ class FunctionCall(OpenAIBaseModel): name: str arguments: str + def random_uuid() -> str: return str(uuid.uuid4().hex) + def make_tool_call_id(id_type: str = "random", func_name=None, idx=None): if id_type == "kimi_k2": return f"functions.{func_name}:{idx}" - else: - # by default return random - return f"chatcmpl-tool-{random_uuid()}" + # by default return random + return f"chatcmpl-tool-{random_uuid()}" class ChatCompletionMessageToolCall(OpenAIBaseModel): @@ -224,6 +227,7 @@ class Message(OpenAIBaseModel): "from this message instead of starting a new assistant turn (prefill / partial mode).", ) + class ChatTemplateKwargs(OpenAIBaseModel): """Represents the arguments for a chat template.""" @@ -232,44 +236,45 @@ class ChatTemplateKwargs(OpenAIBaseModel): default="medium", description="The reasoning effort level." ) + class FunctionDefinition(OpenAIBaseModel): name: str description: str | None = None parameters: dict[str, Any] | None = None + class ChatCompletionToolsParam(OpenAIBaseModel): type: Literal["function"] = "function" function: FunctionDefinition + class ChatCompletionNamedFunction(OpenAIBaseModel): name: str + class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): function: ChatCompletionNamedFunction type: Literal["function"] = "function" -class ChatCompletionRequest(OpenAIBaseModel): +class ChatCompletionRequest(OpenAIBaseModel): """Request schema for OpenAI-compatible chat completion API.""" - + model: str = Field(Config.TEXT_MODEL, description="The model to use for completion.") messages: list[Message] = Field(..., description="The list of messages in the conversation.") tools: list[ChatCompletionToolsParam] | None = Field( None, description="List of tools available for the request." ) - tool_choice: ( - Literal["none"] - | Literal["auto"] - | Literal["required"] - | ChatCompletionNamedToolChoiceParam - | None - ) = "none" + tool_choice: Literal["none", "auto", "required"] | ChatCompletionNamedToolChoiceParam | None = ( + "none" + ) max_tokens: int | None = Field( default=None, - deprecated="max_tokens is deprecated in favor of " - "the max_completion_tokens field", + deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", + ) + max_completion_tokens: int | None = Field( + None, description="Maximum number of tokens to generate." ) - max_completion_tokens: int | None = Field(None, description="Maximum number of tokens to generate.") temperature: float | None = Field(None, description="Sampling temperature.") top_p: float | None = Field(None, description="Nucleus sampling probability.") top_k: int | None = Field(None, description="Top-k sampling parameter.") @@ -284,7 +289,8 @@ class ChatCompletionRequest(OpenAIBaseModel): n: int | None = Field(None, description="Number of completions to generate.") response_format: dict[str, Any] | None = Field(None, description="Format for the response.") seed: int | None = Field( - None, description="The seed to use for sampling.", + None, + description="The seed to use for sampling.", ) user: str | None = Field(None, description="User identifier.") repetition_penalty: float | None = Field( @@ -296,11 +302,10 @@ class ChatCompletionRequest(OpenAIBaseModel): xtc_probability: float | None = Field( None, description="XTC (eXclude Top Choices) sampling probability (0.0-1.0)." ) - xtc_threshold: float | None = Field( - None, description="XTC sampling threshold (0.0-0.5)." - ) + xtc_threshold: float | None = Field(None, description="XTC sampling threshold (0.0-0.5).") logit_bias: dict[str, float] | None = Field( - None, description="Modify the likelihood of specified tokens appearing in the completion. Maps token IDs (as strings) to bias values from -100 to 100." + None, + description="Modify the likelihood of specified tokens appearing in the completion. Maps token IDs (as strings) to bias values from -100 to 100.", ) stream: bool = Field(False, description="Whether to stream the response.") stream_options: StreamOptions | None = None @@ -452,6 +457,7 @@ class ImageSize(str, Enum): MEDIUM = "512x512" LARGE = "1024x1024" + class Priority(str, Enum): """Task priority levels.""" @@ -485,13 +491,9 @@ class TranscriptionResponseFormat(str, Enum): class ImageGenerationRequest(OpenAIBaseModel): """Request schema for OpenAI-compatible image generation API.""" - prompt: str = Field( - ..., - description="A text description of the desired image(s)." - ) + prompt: str = Field(..., description="A text description of the desired image(s).") negative_prompt: str | None = Field( - None, - description="A text description of the desired image(s)." + None, description="A text description of the desired image(s)." ) model: str | None = Field( default=Config.IMAGE_GENERATION_MODEL, description="The model to use for image generation" @@ -546,7 +548,9 @@ class ImageGenerationError(OpenAIBaseModel): class ImageEditRequest(OpenAIBaseModel): """Request data for OpenAI-compatible image edit API.""" - image: UploadFile | list[UploadFile] = Field(..., description="The image(s) to edit. Must be a file upload or a list of file uploads") + image: UploadFile | list[UploadFile] = Field( + ..., description="The image(s) to edit. Must be a file upload or a list of file uploads" + ) prompt: str = Field(..., description="The prompt for the image edit") model: str | None = Field( default=Config.IMAGE_EDIT_MODEL, description="The model to use for image edit" @@ -650,14 +654,9 @@ class TranscriptionResponseStream(OpenAIBaseModel): # --- Responses API Schemas --- -from openai.types.responses import ( - ResponseStatus, - ResponseInputItemParam, - ResponseOutputItem, -) +from openai.types.responses import ResponseInputItemParam, ResponseOutputItem, ResponseStatus +from openai.types.responses.response import IncompleteDetails, Tool, ToolChoice from openai.types.shared import Reasoning -from openai.types.responses.response import Tool, ToolChoice -from openai.types.responses.response import IncompleteDetails ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem @@ -696,9 +695,7 @@ class ResponsesRequest(OpenAIBaseModel): ) seed: int | None = Field(None, description="The seed for the response.") text: ResponseTextConfig | None = None - tools: list[Tool] | None = Field( - None, description="List of tools to use for the response." - ) + tools: list[Tool] | None = Field(None, description="List of tools to use for the response.") tool_choice: ToolChoice | None = Field( default="auto", description="The tool choice to use for the response." ) @@ -717,13 +714,14 @@ class OutputTokensDetails(OpenAIBaseModel): output_tokens_per_turn: list[int] = Field(default_factory=list) tool_output_tokens_per_turn: list[int] = Field(default_factory=list) + class ResponseUsage(OpenAIBaseModel): input_tokens: int input_tokens_details: InputTokensDetails output_tokens: int output_tokens_details: OutputTokensDetails total_tokens: int - + class ResponsesResponse(OpenAIBaseModel): """Represents a complete response from the Responses API.""" diff --git a/app/server.py b/app/server.py index 4b40ac3c..b561b5ad 100644 --- a/app/server.py +++ b/app/server.py @@ -57,6 +57,7 @@ def ensure_image_handler_available(model_type: str) -> None: raise RuntimeError(MFLUX_INSTALL_HINT) + _SAMPLING_DEFAULT_FIELDS: tuple[str, ...] = ( "default_max_tokens", "default_temperature", diff --git a/app/utils/debug_logging.py b/app/utils/debug_logging.py index 106cfd4e..5797d597 100644 --- a/app/utils/debug_logging.py +++ b/app/utils/debug_logging.py @@ -2,12 +2,13 @@ import time from typing import Any + from loguru import logger def log_debug_request(request_dict: dict[str, Any]) -> None: """Log request details in a beautiful format for debug mode. - + Parameters ---------- request_dict : dict[str, Any] @@ -16,25 +17,27 @@ def log_debug_request(request_dict: dict[str, Any]) -> None: logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") logger.info("🔍 DEBUG: Request Details") logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - + # Extract and format key information if "messages" in request_dict: logger.info(f"📨 Messages: {len(request_dict['messages'])} message(s)") for i, msg in enumerate(request_dict["messages"], 1): role = msg.get("role", "unknown") content = msg.get("content", "") - content_preview = str(content)[:100] + "..." if len(str(content)) > 100 else str(content) + content_preview = ( + str(content)[:100] + "..." if len(str(content)) > 100 else str(content) + ) logger.info(f" {i}. [{role}] {content_preview}") - + if request_dict.get("max_tokens"): logger.info(f"🎯 Max Tokens: {request_dict['max_tokens']:,}") - + if request_dict.get("temperature"): logger.info(f"🌡️ Temperature: {request_dict['temperature']}") - + if request_dict.get("top_p"): logger.info(f"🎲 Top P: {request_dict['top_p']}") - + logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") @@ -46,7 +49,7 @@ def log_debug_stats( peak_memory: float, ) -> None: """Log generation statistics in a beautiful format for debug mode. - + Parameters ---------- prompt_tokens : int @@ -73,7 +76,7 @@ def log_debug_stats( def log_debug_prompt(prompt: str) -> None: """Log input prompt in a beautiful format for debug mode. - + Parameters ---------- prompt : str @@ -86,7 +89,7 @@ def log_debug_prompt(prompt: str) -> None: def log_debug_raw_text_response(raw_text: str) -> None: """Log raw text response in a beautiful format for debug mode. - + Parameters ---------- raw_text : str @@ -101,7 +104,7 @@ def log_debug_raw_text_response(raw_text: str) -> None: def log_debug_cache_stats(total_input_tokens: int, remaining_tokens: int) -> None: """Log prompt cache statistics in a beautiful format for debug mode. - + Parameters ---------- total_input_tokens : int @@ -155,14 +158,15 @@ def log_debug_chat_template( logger.info("✦ Using default chat template from model") logger.info("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + def make_prompt_progress_callback(start_time: float | None = None) -> callable: """Create a callback function for tracking prompt processing progress. - + Parameters ---------- start_time : float | None The start time for calculating speed. If None, uses current time. - + Returns ------- callable @@ -170,13 +174,13 @@ def make_prompt_progress_callback(start_time: float | None = None) -> callable: """ if start_time is None: start_time = time.time() - + def callback(processed: int, total_tokens: int) -> None: """Log prompt processing progress with speed metrics.""" elapsed = time.time() - start_time speed = processed / elapsed if elapsed > 0 else 0 logger.info(f"⚡ Processed {processed:6d}/{total_tokens} tokens ({speed:6.2f} tok/s)") - + return callback diff --git a/app/utils/prompt_cache.py b/app/utils/prompt_cache.py index e17a813a..c2966dbe 100644 --- a/app/utils/prompt_cache.py +++ b/app/utils/prompt_cache.py @@ -2,8 +2,8 @@ from __future__ import annotations -import copy from collections import deque +import copy from dataclasses import dataclass from typing import Any @@ -139,6 +139,7 @@ def _search(self, tokens_ids: list[int]) -> SearchResult: ---------- tokens_ids : list[int] Token sequence to search for. + Returns ------- SearchResult @@ -182,7 +183,6 @@ def _search(self, tokens_ids: list[int]) -> SearchResult: return self.SearchResult(None, shorter, longer, common_prefix) - def fetch_nearest_cache( self, tokens_ids: list[int], @@ -193,6 +193,7 @@ def fetch_nearest_cache( ---------- tokens_ids : list[int] Token sequence to find a cache for. + Returns ------- tuple[list[Any] | None, list[int]] @@ -227,6 +228,7 @@ def _get(self, tokens_ids: list[int]) -> CacheEntry: ---------- tokens_ids : list[int] Token sequence identifying the cache entry. + Returns ------- CacheEntry @@ -344,8 +346,10 @@ def log_cache_stats(self) -> None: latest_checkpoint_tokens, ) + if __name__ == "__main__": from app.models.mlx_lm import MLX_LM + model_path = "mlx-community/Qwen3-Coder-Next-8bit" draft_model_path = "mlx-community/Qwen3-Coder-Next-4bit" model = MLX_LM(model_path, draft_model_path) @@ -374,7 +378,6 @@ def log_cache_stats(self) -> None: first_token = False cache_key.append(chunk.token) - prompt_cache.insert_cache(cache_key, cache) start_time = time.time() @@ -383,7 +386,7 @@ def log_cache_stats(self) -> None: input_prompt_2 = model.create_input_prompt([{"role": "user", "content": prompt_2}], {}) input_ids_2 = model.encode_prompt(input_prompt_2) cache, rest_input_ids_2 = prompt_cache.fetch_nearest_cache(input_ids_2) - + if cache is None: cache = model.create_prompt_cache() # Use full input_ids for cache_key, not rest_input_ids @@ -402,4 +405,4 @@ def log_cache_stats(self) -> None: print("RAW TEXT", raw_text) - prompt_cache.insert_cache(cache_key_2, cache) \ No newline at end of file + prompt_cache.insert_cache(cache_key_2, cache) diff --git a/configure_mlx.sh b/configure_mlx.sh index f1cfe6e6..00e1bb80 100644 --- a/configure_mlx.sh +++ b/configure_mlx.sh @@ -40,4 +40,4 @@ else sudo sysctl -w iogpu.wired_limit_mb=$WIRED_LIMIT_MB sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB 2>/dev/null || \ sudo sysctl -w iogpu.wired_lwm_mb=$WIRED_LWM_MB -fi \ No newline at end of file +fi diff --git a/examples/chat_templates/Longcat-Flash-Lite.jinja b/examples/chat_templates/Longcat-Flash-Lite.jinja index 3166cb94..45b06416 100644 --- a/examples/chat_templates/Longcat-Flash-Lite.jinja +++ b/examples/chat_templates/Longcat-Flash-Lite.jinja @@ -65,4 +65,4 @@ {%- endfor %} {%- if add_generation_prompt %} {{- "" }} -{%- endif %} \ No newline at end of file +{%- endif %} diff --git a/examples/chat_templates/llama4.jinja b/examples/chat_templates/llama4.jinja index 386d0320..bbed3d82 100644 --- a/examples/chat_templates/llama4.jinja +++ b/examples/chat_templates/llama4.jinja @@ -108,4 +108,4 @@ {%- endfor %} {%- if add_generation_prompt %} {{- '<|header_start|>assistant<|header_end|>\n\n' }} -{%- endif %} \ No newline at end of file +{%- endif %} diff --git a/examples/config.yaml b/examples/config.yaml index 02d011ae..d30ab8bd 100644 --- a/examples/config.yaml +++ b/examples/config.yaml @@ -33,4 +33,4 @@ models: model_type: image-generation config_name: flux2-klein-4b quantize: 4 - model_id: flux2-klein-4b \ No newline at end of file + model_id: flux2-klein-4b diff --git a/pyproject.toml b/pyproject.toml index 86c8c167..5c3724e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,8 @@ dependencies = [ "mlx-lm>=0.31.0,<0.32", "mlx-vlm>=0.4.0,<0.5", "mlx-whisper>=0.4.3,<0.5", - "mlx>=0.31.0", - "numpy>=2.2.0,<2.4", + "mlx>=0.31.0", + "numpy>=2.2.0,<2.5", "openai>=2.21,<3", "openai-harmony>=0.0.8,<0.1", "outlines>=1.1.1,<1.2", @@ -323,4 +323,4 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = ["untyped_package.*"] -follow_untyped_imports = true \ No newline at end of file +follow_untyped_imports = true diff --git a/tests/parsers/test_harmony.py b/tests/parsers/test_harmony.py index f9f47ab5..c9fd8289 100644 --- a/tests/parsers/test_harmony.py +++ b/tests/parsers/test_harmony.py @@ -31,11 +31,11 @@ def test_harmony_reasoning_and_tool_parsing_streaming() -> None: "json", "<|message|>", "{", - "\"city\"", + '"city"', ":", - "\"Tokyo\"", + '"Tokyo"', "}", - "<|call|>" + "<|call|>", ] results = [] @@ -57,7 +57,7 @@ def test_harmony_reasoning_and_tool_parsing_streaming() -> None: assert "tool_calls" in final_result assert final_result["tool_calls"] is not None assert len(final_result["tool_calls"]) == 1 - + tool_call = final_result["tool_calls"][0] assert tool_call["name"] == "get_weather" assert "Tokyo" in tool_call["arguments"] @@ -79,7 +79,10 @@ def test_harmony_non_streaming_parse() -> None: assert result is not None assert "reasoning_content" in result assert result["reasoning_content"] is not None - assert "need" in result["reasoning_content"].lower() or "call" in result["reasoning_content"].lower() + assert ( + "need" in result["reasoning_content"].lower() + or "call" in result["reasoning_content"].lower() + ) # Verify tool calls assert "tool_calls" in result @@ -87,7 +90,8 @@ def test_harmony_non_streaming_parse() -> None: assert result["tool_calls"][0]["name"] == "get_weather" assert "Tokyo" in result["tool_calls"][0]["arguments"] + if __name__ == "__main__": test_harmony_reasoning_and_tool_parsing_streaming() test_harmony_non_streaming_parse() - print("All tests passed!") \ No newline at end of file + print("All tests passed!") diff --git a/tests/test_multi_model_per_model_defaults.py b/tests/test_multi_model_per_model_defaults.py index 67a71f6f..dd70cb50 100644 --- a/tests/test_multi_model_per_model_defaults.py +++ b/tests/test_multi_model_per_model_defaults.py @@ -103,9 +103,7 @@ def _make_raw_request(registry: _FakeRegistry, handler: Any | None = None) -> An """Build a minimal request-like object for endpoint unit tests.""" return types.SimpleNamespace( - app=types.SimpleNamespace( - state=types.SimpleNamespace(registry=registry, handler=handler) - ), + app=types.SimpleNamespace(state=types.SimpleNamespace(registry=registry, handler=handler)), state=types.SimpleNamespace(request_id="req-test"), ) @@ -335,7 +333,9 @@ async def _fake_process_text_request( await endpoints_module.chat_completions(request, _make_raw_request(registry)) - assert captured_requests[0].temperature == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"])) + assert captured_requests[0].temperature == pytest.approx( + float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"]) + ) assert captured_requests[0].top_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_P"])) assert captured_requests[0].top_k == int(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_K"]) assert captured_requests[0].min_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_MIN_P"])) @@ -343,7 +343,9 @@ async def _fake_process_text_request( float(GLOBAL_ENV_DEFAULTS["DEFAULT_REPETITION_PENALTY"]) ) assert captured_requests[0].seed == int(GLOBAL_ENV_DEFAULTS["DEFAULT_SEED"]) - assert captured_requests[0].max_completion_tokens == int(GLOBAL_ENV_DEFAULTS["DEFAULT_MAX_TOKENS"]) + assert captured_requests[0].max_completion_tokens == int( + GLOBAL_ENV_DEFAULTS["DEFAULT_MAX_TOKENS"] + ) assert captured_requests[0].xtc_probability == pytest.approx( float(GLOBAL_ENV_DEFAULTS["DEFAULT_XTC_PROBABILITY"]) ) @@ -509,7 +511,9 @@ async def test_responses_single_model_implicit_handler_defaults_do_not_shadow_en refined_request = endpoints_module.refine_responses_request(request, handler) - assert refined_request.temperature == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"])) + assert refined_request.temperature == pytest.approx( + float(GLOBAL_ENV_DEFAULTS["DEFAULT_TEMPERATURE"]) + ) assert refined_request.top_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_P"])) assert refined_request.top_k == int(GLOBAL_ENV_DEFAULTS["DEFAULT_TOP_K"]) assert refined_request.min_p == pytest.approx(float(GLOBAL_ENV_DEFAULTS["DEFAULT_MIN_P"])) diff --git a/tests/test_prompt_cache_cancellation.py b/tests/test_prompt_cache_cancellation.py index c214580a..1dabbb63 100644 --- a/tests/test_prompt_cache_cancellation.py +++ b/tests/test_prompt_cache_cancellation.py @@ -87,7 +87,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: chunk.token = 4 yield chunk - mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen()) + mock_inference_worker.submit_stream = Mock( + side_effect=lambda *args, **kwargs: mock_response_gen() + ) mock_prompt_cache = Mock() mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), []) @@ -115,7 +117,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: mock_parsers_result.is_unified = False mock_parsers_result.reasoning_parser = None mock_parsers_result.tool_parser = None - with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result): + with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result): fake_request = Mock() gen = handler.generate_text_stream(fake_request) try: @@ -158,7 +160,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: chunk.generation_tokens = 20 yield chunk - mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen()) + mock_inference_worker.submit_stream = Mock( + side_effect=lambda *args, **kwargs: mock_response_gen() + ) mock_prompt_cache = Mock() mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), []) @@ -184,7 +188,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: mock_parsers_result.is_unified = False mock_parsers_result.reasoning_parser = None mock_parsers_result.tool_parser = None - with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result): + with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result): fake_request = Mock() gen = handler.generate_text_stream(fake_request) try: @@ -246,7 +250,7 @@ async def blocked_response_gen() -> AsyncGenerator[None, None]: mock_parsers_result.is_unified = False mock_parsers_result.reasoning_parser = None mock_parsers_result.tool_parser = None - with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result): + with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result): fake_request = Mock() gen = handler.generate_text_stream(fake_request) task = asyncio.create_task(gen.__anext__()) @@ -284,7 +288,9 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: chunk.token = token yield chunk - mock_inference_worker.submit_stream = Mock(side_effect=lambda *args, **kwargs: mock_response_gen()) + mock_inference_worker.submit_stream = Mock( + side_effect=lambda *args, **kwargs: mock_response_gen() + ) mock_prompt_cache = Mock() mock_prompt_cache.fetch_nearest_cache.return_value = (Mock(name="cache_obj"), []) @@ -310,7 +316,7 @@ async def mock_response_gen() -> AsyncGenerator[Mock, None]: mock_parsers_result.is_unified = False mock_parsers_result.reasoning_parser = None mock_parsers_result.tool_parser = None - with patch('app.handler.mlx_lm.ParserManager.create_parsers', return_value=mock_parsers_result): + with patch("app.handler.mlx_lm.ParserManager.create_parsers", return_value=mock_parsers_result): fake_request = Mock() gen = handler.generate_text_stream(fake_request) try: diff --git a/tests/test_responses_request_conversion.py b/tests/test_responses_request_conversion.py index 945819d9..a6fecf0a 100644 --- a/tests/test_responses_request_conversion.py +++ b/tests/test_responses_request_conversion.py @@ -133,7 +133,7 @@ def test_convert_responses_stream_history_skips_reasoning_but_keeps_tool_turns() "status": "completed", "call_id": "call_123", "name": "du", - "arguments": "{\"path\":\"/tmp\"}", + "arguments": '{"path":"/tmp"}', }, { "type": "function_call_output", @@ -180,10 +180,12 @@ def test_convert_responses_stream_history_skips_reasoning_but_keeps_tool_turns() "user", ] - tool_call_messages = [m for m in chat_request.messages if m.role == "assistant" and m.tool_calls] + tool_call_messages = [ + m for m in chat_request.messages if m.role == "assistant" and m.tool_calls + ] assert len(tool_call_messages) == 1 assert tool_call_messages[0].tool_calls[0].function.name == "du" - assert tool_call_messages[0].tool_calls[0].function.arguments == "{\"path\":\"/tmp\"}" + assert tool_call_messages[0].tool_calls[0].function.arguments == '{"path":"/tmp"}' assert tool_call_messages[0].tool_calls[0].id == "call_123" tool_output_messages = [m for m in chat_request.messages if m.role == "tool"]