Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions cli/serve/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, RootModel, model_validator

from mellea.helpers.openai_compatible_helpers import CompletionUsage

Expand All @@ -13,9 +13,28 @@ class ChatMessage(BaseModel):
function_call: dict[str, Any] | None = None # For function/tool messages


class FunctionParameters(BaseModel):
# Accept any structure for function parameters
RootModel: dict[str, Any]
class FunctionParameters(RootModel[dict[str, Any]]):
"""OpenAI-compatible function parameters as a bare JSON Schema object.

Accepts a standard JSON Schema dict directly without wrapping.
Example: {"type": "object", "properties": {...}, "required": [...]}
"""

root: dict[str, Any]
Comment thread
planetf1 marked this conversation as resolved.

@model_validator(mode="after")
def _reject_legacy_envelope(self) -> "FunctionParameters":
"""Reject legacy RootModel envelope pattern.

Ensures parameters are sent as a bare JSON Schema object, not wrapped
in a {"RootModel": {...}} envelope which would be invalid.
"""
if set(self.root.keys()) == {"RootModel"}:
raise ValueError(
"Legacy {'RootModel': {...}} envelope is no longer accepted. "
"Send parameters as a bare JSON Schema object."
)
return self


class FunctionDefinition(BaseModel):
Expand All @@ -41,7 +60,7 @@ class JsonSchemaFormat(BaseModel):
strict: bool | None = None
"""Accepted for OpenAI compatibility; currently ignored by ``m serve``."""

model_config = {"populate_by_name": True}
model_config = {"populate_by_name": True, "serialize_by_alias": True}


class ResponseFormat(BaseModel):
Expand Down Expand Up @@ -73,10 +92,6 @@ class StreamOptions(BaseModel):
"""


class LogitBias(BaseModel):
RootModel: dict[str, float]


class ChatCompletionRequest(BaseModel):
model_config = {"extra": "allow"}

Expand Down
48 changes: 23 additions & 25 deletions docs/examples/m_serve/client_streaming_tool_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,19 @@
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"RootModel": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name, e.g. San Francisco",
},
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name, e.g. San Francisco",
},
"required": ["location"],
}
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
},
"required": ["location"],
},
},
},
Expand All @@ -52,16 +50,14 @@
"name": "get_stock_price",
"description": "Get the current stock price for a given ticker symbol",
"parameters": {
"RootModel": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
}
},
"required": ["symbol"],
}
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
}
},
"required": ["symbol"],
},
},
},
Expand Down Expand Up @@ -232,7 +228,9 @@ def main():
# Example 4: Multi-turn conversation with tool use
print("\n\n4. Multi-turn Conversation (Streaming)")
print("-" * 60)
messages = [{"role": "user", "content": "What's the weather in Paris?"}]
messages: list[dict[str, Any]] = [
{"role": "user", "content": "What's the weather in Paris?"}
]

print(f"User: {messages[0]['content']}")
print("\nAssistant: ", end="", flush=True)
Expand Down
44 changes: 20 additions & 24 deletions docs/examples/m_serve/client_tool_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,19 @@
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"RootModel": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name, e.g. San Francisco",
},
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city name, e.g. San Francisco",
},
"required": ["location"],
}
"units": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "Temperature units",
},
},
"required": ["location"],
},
},
},
Expand All @@ -51,16 +49,14 @@
"name": "get_stock_price",
"description": "Get the current stock price for a given ticker symbol",
"parameters": {
"RootModel": {
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
}
},
"required": ["symbol"],
}
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
}
},
"required": ["symbol"],
},
},
},
Expand Down
4 changes: 2 additions & 2 deletions test/cli/test_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ async def test_tool_params_passed_to_model_options(self, mock_module):
function=FunctionDefinition(
name="test_func",
description="A test function",
parameters=FunctionParameters(RootModel={"type": "object"}),
parameters=FunctionParameters({"type": "object"}),
),
)
],
Expand All @@ -471,7 +471,7 @@ async def test_tool_params_passed_to_model_options(self, mock_module):
FunctionDefinition(
name="legacy_func",
description="A legacy function",
parameters=FunctionParameters(RootModel={"type": "object"}),
parameters=FunctionParameters({"type": "object"}),
)
],
function_call="auto",
Expand Down
42 changes: 41 additions & 1 deletion test/cli/test_serve_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def as_tool_function(self) -> ToolFunction:
name="get_weather",
description="Get the current weather in a location",
parameters=FunctionParameters(
RootModel={
{
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"},
Expand Down Expand Up @@ -571,3 +571,43 @@ def test_server_error_returns_500(self, client, mock_module):
assert data["error"]["type"] == "server_error"
assert "Internal error" not in data["error"]["message"]
assert "Internal server error" in data["error"]["message"]

def test_legacy_root_model_envelope_rejected_via_http(self, client, mock_module):
"""Test that legacy {'RootModel': {...}} envelope is rejected at HTTP layer.

Verifies that the FunctionParameters validator catches the legacy envelope
pattern and returns a proper 400 error via the HTTP API.
"""
# Send request with legacy envelope in function parameters
response = client.post(
"/v1/chat/completions",
json={
"model": "test-model",
"messages": [{"role": "user", "content": "What's the weather?"}],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {
"RootModel": {
"type": "object",
"properties": {"location": {"type": "string"}},
}
},
},
}
],
},
)

# Should return 400 with validation error
assert response.status_code == 400
data = response.json()
assert "error" in data
assert data["error"]["type"] == "invalid_request_error"
assert (
"Legacy {'RootModel': {...}} envelope is no longer accepted"
in data["error"]["message"]
)
77 changes: 76 additions & 1 deletion test/cli/test_serve_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from pydantic import ValidationError

from cli.serve.models import StreamOptions
from cli.serve.models import FunctionParameters, JsonSchemaFormat, StreamOptions


class TestStreamOptions:
Expand Down Expand Up @@ -79,3 +79,78 @@ def test_model_dump_json_serialization(self):
json_str = options.model_dump_json()
assert "include_usage" in json_str
assert "true" in json_str.lower()


class TestFunctionParameters:
"""Tests for the FunctionParameters RootModel validator."""

def test_valid_json_schema_accepted(self):
"""Test that a valid JSON Schema dict is accepted."""
schema = {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
}
params = FunctionParameters(root=schema)
assert params.root == schema

def test_legacy_root_model_envelope_rejected(self):
"""Test that legacy {'RootModel': {...}} envelope is rejected."""
legacy_envelope = {
"RootModel": {
"type": "object",
"properties": {"location": {"type": "string"}},
}
}
with pytest.raises(ValidationError) as exc_info:
FunctionParameters(root=legacy_envelope)

errors = exc_info.value.errors()
assert len(errors) == 1
error_msg = str(exc_info.value)
assert "Legacy {'RootModel': {...}} envelope is no longer accepted" in error_msg

def test_root_model_with_additional_keys_accepted(self):
"""Test that a dict with 'RootModel' plus other keys is accepted."""
# This is a valid schema that happens to have a property named "RootModel"
schema = {
"type": "object",
"properties": {
"RootModel": {"type": "string"},
"other_field": {"type": "number"},
},
}
params = FunctionParameters(root=schema)
assert params.root == schema

def test_empty_dict_accepted(self):
"""Test that an empty dict is accepted (though not a useful schema)."""
params = FunctionParameters(root={})
assert params.root == {}


class TestJsonSchemaFormat:
"""Test JsonSchemaFormat serialization uses 'schema' alias, not 'schema_'."""

def test_serialization_uses_schema_alias(self):
"""Verify schema_ serializes as 'schema' in dict and JSON output."""
schema_def = {"type": "object", "properties": {"foo": {"type": "string"}}}
json_schema = JsonSchemaFormat(name="TestSchema", schema=schema_def)

# Dict serialization
dumped = json_schema.model_dump()
assert "schema" in dumped and "schema_" not in dumped
assert dumped["schema"] == schema_def

# JSON serialization
json_str = json_schema.model_dump_json()
assert '"schema":' in json_str and '"schema_":' not in json_str

# Input accepts both 'schema' (alias) and 'schema_' (field name)
from_alias = JsonSchemaFormat(name="Test1", schema={"type": "string"})
# Use model_validate to test runtime populate_by_name behavior (bypasses type checker)
from_field = JsonSchemaFormat.model_validate(
{"name": "Test2", "schema_": {"type": "number"}}
)
assert from_alias.schema_ == {"type": "string"}
assert from_field.schema_ == {"type": "number"}
Loading
Loading