Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tests/test_request_auth_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _build_request(headers=None):
state.callbacks = None
state.request_stats_monitor.get_request_stats.return_value = {}
state.engine_stats_scraper.get_engine_stats.return_value = {}
state.router.route_request.return_value = "http://whisper-engine"
state.router.route_request = AsyncMock(return_value="http://whisper-engine")

request = MagicMock()
request.headers = headers or {}
Expand Down
102 changes: 96 additions & 6 deletions src/vllm_router/services/request_service/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import abc
import json

from vllm_router.log import init_logger
from vllm_router.utils import SingletonABCMeta
Expand Down Expand Up @@ -70,6 +71,92 @@ def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str:
return request_body


class MessagesRewriter(RequestRewriter):
    """
    A request rewriter for Anthropic Messages API and OpenAI Chat Completions API
    requests that normalizes messages before forwarding to the backend.

    Normalizations:
    - Filters out messages with empty/null content (some backends reject them).
    - For ``/v1/messages``, promotes ``role: "system"`` entries in the messages
      array to the top-level ``system`` parameter (handles the
      ``mid-conversation-system`` beta format sent by e.g. Claude Code).

    The rewriter is best-effort: a body it cannot interpret (invalid JSON, a
    JSON value that is not an object, a missing/empty/non-list ``messages``
    field, or a ``messages`` list containing non-object entries) is returned
    unchanged so that the backend performs the validation.
    """

    _ANTHROPIC_ENDPOINT = "/v1/messages"
    _CHAT_COMPLETIONS_ENDPOINTS = ("/v1/chat/completions", "/chat/completions")

    def rewrite_request(self, request_body: str, model: str, endpoint: str) -> str:
        """
        Normalize the ``messages`` array of a request body.

        Args:
            request_body: The raw JSON request body.
            model: The target model name (not used by this rewriter).
            endpoint: The request path; only ``/v1/messages``,
                ``/v1/chat/completions`` and ``/chat/completions`` are rewritten.

        Returns:
            The re-serialized JSON body when a rewrite was applied, otherwise
            ``request_body`` unchanged.
        """
        is_anthropic = endpoint == self._ANTHROPIC_ENDPOINT
        # Every other endpoint is passed through untouched, so skip the JSON
        # round-trip entirely for them.
        if not is_anthropic and endpoint not in self._CHAT_COMPLETIONS_ENDPOINTS:
            return request_body

        try:
            body = json.loads(request_body)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError (undecodable byte
            # bodies) are both ValueError subclasses.
            return request_body

        # Valid JSON is not necessarily an object (e.g. "[]" or "\"text\"").
        if not isinstance(body, dict):
            return request_body

        messages = body.get("messages")
        if not messages or not isinstance(messages, list):
            return request_body

        # Malformed entries are left for the backend to reject; the filters
        # below rely on every message being a JSON object.
        if not all(isinstance(m, dict) for m in messages):
            return request_body

        # Guard: skip messages with empty content (some backends reject them).
        messages = [m for m in messages if _message_has_content(m)]
        if not messages:
            return request_body

        # For Anthropic Messages API, also promote role: "system" to the
        # top-level system param.
        if is_anthropic:
            system_messages = [m for m in messages if m.get("role") == "system"]
            if system_messages:
                messages = [m for m in messages if m.get("role") != "system"]
                system_content = _join_system_content(system_messages)
                # Nothing textual to promote (e.g. only non-text blocks): do not
                # inject an empty system prompt or an empty text block.
                if system_content:
                    body["system"] = self._merge_system(
                        body.get("system"), system_content
                    )
                logger.info(
                    "Promoted %d system message(s) from messages array to top-level system param",
                    len(system_messages),
                )

        body["messages"] = messages
        return json.dumps(body)

    @staticmethod
    def _merge_system(existing, promoted: str):
        """Combine an existing top-level ``system`` value with promoted text.

        A string is extended on a new line, a list of content blocks gets an
        extra text block, and a missing or unsupported value is replaced.
        """
        if isinstance(existing, str):
            return existing + "\n" + promoted if existing else promoted
        if isinstance(existing, list):
            return existing + [{"type": "text", "text": promoted}]
        return promoted


def _message_has_content(message: dict) -> bool:
    """Return True when ``message`` carries a non-empty ``content`` value.

    A missing or ``None`` content, a whitespace-only string and an empty list
    all count as "no content"; any other value is judged by its truthiness.
    """
    payload = message.get("content")
    if isinstance(payload, str):
        # Whitespace-only text is treated as empty.
        return bool(payload.strip())
    # None, empty lists and other falsy values are all "no content".
    return bool(payload)


def _join_system_content(system_messages: list[dict]) -> str:
    """Flatten the content of system messages into one newline-joined string.

    String content is used as-is. For list content, only dict blocks whose
    ``type`` is ``"text"`` contribute; every other block is ignored. A text
    block whose ``text`` is missing, ``None`` or otherwise not a string
    contributes an empty line instead of breaking the join.

    Args:
        system_messages: Message dicts whose ``content`` is a string or a list
            of content blocks.

    Returns:
        The collected text parts joined with ``"\\n"`` (``""`` if there are none).
    """
    parts = []
    for msg in system_messages:
        content = msg.get("content")
        if isinstance(content, str):
            parts.append(content)
        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict) and block.get("type") == "text":
                    text = block.get("text")
                    # str.join raises TypeError on non-str items, so coerce a
                    # null/non-string text to the same "" used for a missing key.
                    parts.append(text if isinstance(text, str) else "")
    return "\n".join(parts)


# Singleton instance: the module-wide request rewriter. It starts as None and
# is assigned by the functions below that declare it ``global``.
_request_rewriter_instance = None

Expand All @@ -87,10 +174,12 @@ def initialize_request_rewriter(rewriter_type: str, **kwargs) -> RequestRewriter
"""
global _request_rewriter_instance

# TODO: Implement different rewriter types
# For now, just use the NoopRequestRewriter
_request_rewriter_instance = NoopRequestRewriter()
logger.info(f"Initialized placeholder request rewriter (type: {rewriter_type})")
if rewriter_type == "messages":
_request_rewriter_instance = MessagesRewriter()
logger.info("Initialized MessagesRewriter")
else:
_request_rewriter_instance = NoopRequestRewriter()
logger.info(f"Initialized placeholder request rewriter (type: {rewriter_type})")

return _request_rewriter_instance

Expand All @@ -111,9 +200,10 @@ def get_request_rewriter() -> RequestRewriter:
Get the request rewriter singleton instance.

Returns:
The request rewriter instance or NoopRequestRewriter if not initialized
The request rewriter instance or MessagesRewriter if not initialized
"""
global _request_rewriter_instance
if _request_rewriter_instance is None:
_request_rewriter_instance = NoopRequestRewriter()
_request_rewriter_instance = MessagesRewriter()
logger.info("Initialized default MessagesRewriter")
return _request_rewriter_instance
Loading