diff --git a/src/tests/test_request_auth_headers.py b/src/tests/test_request_auth_headers.py index 1ea6b19a2..2382918fb 100644 --- a/src/tests/test_request_auth_headers.py +++ b/src/tests/test_request_auth_headers.py @@ -33,7 +33,7 @@ def _build_request(headers=None): state.callbacks = None state.request_stats_monitor.get_request_stats.return_value = {} state.engine_stats_scraper.get_engine_stats.return_value = {} - state.router.route_request.return_value = "http://whisper-engine" + state.router.route_request = AsyncMock(return_value="http://whisper-engine") request = MagicMock() request.headers = headers or {} diff --git a/src/vllm_router/services/request_service/rewriter.py b/src/vllm_router/services/request_service/rewriter.py index 0f4263c6d..d108842a9 100644 --- a/src/vllm_router/services/request_service/rewriter.py +++ b/src/vllm_router/services/request_service/rewriter.py @@ -19,6 +19,7 @@ """ import abc +import json from vllm_router.log import init_logger from vllm_router.utils import SingletonABCMeta @@ -70,6 +71,92 @@ def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str: return request_body +class MessagesRewriter(RequestRewriter): + """ + A request rewriter for Anthropic Messages API and OpenAI Chat Completions API + requests that normalizes messages before forwarding to the backend. + + Normalizations: + - Filters out messages with empty/null content (some backends reject them). + - For ``/v1/messages``, promotes ``role: "system"`` entries in the messages + array to the top-level ``system`` parameter (handles the ``mid-conversation-system`` + beta format sent by e.g. Claude Code). + """ + + def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str: + try: + body = json.loads(request_body) + except json.JSONDecodeError: + return request_body + + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return request_body + + # Guard: skip messages with empty content (some backends reject them). + messages = [m for m in messages if _message_has_content(m)] + + if not messages: + return request_body + + # For Anthropic Messages API, also promote role: "system" to top-level system param. + if endpoint == "/v1/messages": + system_messages = [m for m in messages if m.get("role") == "system"] + if system_messages: + system_content = _join_system_content(system_messages) + body["messages"] = [m for m in messages if m.get("role") != "system"] + if body.get("system") is not None: + existing = body["system"] + if isinstance(existing, str): + body["system"] = existing + "\n" + system_content + elif isinstance(existing, list): + body["system"].append({"type": "text", "text": system_content}) + else: + body["system"] = system_content + else: + body["system"] = system_content + + logger.info( + "Promoted %d system message(s) from messages array to top-level system param", + len(system_messages), + ) + return json.dumps(body) + + body["messages"] = messages + return json.dumps(body) + + # For chat completions, just apply the empty-content guard. + if endpoint in ("/v1/chat/completions", "/chat/completions"): + body["messages"] = messages + return json.dumps(body) + + return request_body + + +def _message_has_content(message: dict) -> bool: + content = message.get("content") + if content is None: + return False + if isinstance(content, str): + return content.strip() != "" + if isinstance(content, list): + return len(content) > 0 + return bool(content) + + +def _join_system_content(system_messages: list[dict]) -> str: + parts = [] + for msg in system_messages: + content = msg.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + return "\n".join(parts) + + # Singleton instance _request_rewriter_instance = None @@ -87,10 +174,12 @@ def initialize_request_rewriter(rewriter_type: str, **kwargs) -> RequestRewriter """ global _request_rewriter_instance - # TODO: Implement different rewriter types - # For now, just use the NoopRequestRewriter - _request_rewriter_instance = NoopRequestRewriter() - logger.info(f"Initialized placeholder request rewriter (type: {rewriter_type})") + if rewriter_type == "messages": + _request_rewriter_instance = MessagesRewriter() + logger.info("Initialized MessagesRewriter") + else: + _request_rewriter_instance = NoopRequestRewriter() + logger.info(f"Initialized placeholder request rewriter (type: {rewriter_type})") return _request_rewriter_instance @@ -111,9 +200,10 @@ def get_request_rewriter() -> RequestRewriter: Get the request rewriter singleton instance. Returns: - The request rewriter instance or NoopRequestRewriter if not initialized + The request rewriter instance or MessagesRewriter if not initialized """ global _request_rewriter_instance if _request_rewriter_instance is None: - _request_rewriter_instance = NoopRequestRewriter() + _request_rewriter_instance = MessagesRewriter() + logger.info("Initialized default MessagesRewriter") return _request_rewriter_instance