diff --git a/cli/serve/models.py b/cli/serve/models.py index 7e50638b8..a775e4c0a 100644 --- a/cli/serve/models.py +++ b/cli/serve/models.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, RootModel, model_validator from mellea.helpers.openai_compatible_helpers import CompletionUsage @@ -13,9 +13,28 @@ class ChatMessage(BaseModel): function_call: dict[str, Any] | None = None # For function/tool messages -class FunctionParameters(BaseModel): - # Accept any structure for function parameters - RootModel: dict[str, Any] +class FunctionParameters(RootModel[dict[str, Any]]): + """OpenAI-compatible function parameters as a bare JSON Schema object. + + Accepts a standard JSON Schema dict directly without wrapping. + Example: {"type": "object", "properties": {...}, "required": [...]} + """ + + root: dict[str, Any] + + @model_validator(mode="after") + def _reject_legacy_envelope(self) -> "FunctionParameters": + """Reject legacy RootModel envelope pattern. + + Ensures parameters are sent as a bare JSON Schema object, not wrapped + in a {"RootModel": {...}} envelope which would be invalid. + """ + if set(self.root.keys()) == {"RootModel"}: + raise ValueError( + "Legacy {'RootModel': {...}} envelope is no longer accepted. " + "Send parameters as a bare JSON Schema object." + ) + return self class FunctionDefinition(BaseModel): @@ -41,7 +60,7 @@ class JsonSchemaFormat(BaseModel): strict: bool | None = None """Accepted for OpenAI compatibility; currently ignored by ``m serve``.""" - model_config = {"populate_by_name": True} + model_config = {"populate_by_name": True, "serialize_by_alias": True} class ResponseFormat(BaseModel): @@ -73,10 +92,6 @@ class StreamOptions(BaseModel): """ -class LogitBias(BaseModel): - RootModel: dict[str, float] - - class ChatCompletionRequest(BaseModel): model_config = {"extra": "allow"} diff --git a/docs/examples/m_serve/client_streaming_tool_calling.py b/docs/examples/m_serve/client_streaming_tool_calling.py index 2a406564e..6e57c67a4 100644 --- a/docs/examples/m_serve/client_streaming_tool_calling.py +++ b/docs/examples/m_serve/client_streaming_tool_calling.py @@ -28,21 +28,19 @@ "name": "get_weather", "description": "Get the current weather in a given location", "parameters": { - "RootModel": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city name, e.g. San Francisco", - }, - "units": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature units", - }, + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", }, - "required": ["location"], - } + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], }, }, }, @@ -52,16 +50,14 @@ "name": "get_stock_price", "description": "Get the current stock price for a given ticker symbol", "parameters": { - "RootModel": { - "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock ticker symbol, e.g. AAPL, GOOGL", - } - }, - "required": ["symbol"], - } + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], }, }, }, @@ -232,7 +228,9 @@ def main(): # Example 4: Multi-turn conversation with tool use print("\n\n4. Multi-turn Conversation (Streaming)") print("-" * 60) - messages = [{"role": "user", "content": "What's the weather in Paris?"}] + messages: list[dict[str, Any]] = [ + {"role": "user", "content": "What's the weather in Paris?"} + ] print(f"User: {messages[0]['content']}") print("\nAssistant: ", end="", flush=True) diff --git a/docs/examples/m_serve/client_tool_calling.py b/docs/examples/m_serve/client_tool_calling.py index d68e5d238..c5b78e490 100644 --- a/docs/examples/m_serve/client_tool_calling.py +++ b/docs/examples/m_serve/client_tool_calling.py @@ -27,21 +27,19 @@ "name": "get_weather", "description": "Get the current weather in a given location", "parameters": { - "RootModel": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city name, e.g. San Francisco", - }, - "units": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "Temperature units", - }, + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name, e.g. San Francisco", }, - "required": ["location"], - } + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], }, }, }, @@ -51,16 +49,14 @@ "name": "get_stock_price", "description": "Get the current stock price for a given ticker symbol", "parameters": { - "RootModel": { - "type": "object", - "properties": { - "symbol": { - "type": "string", - "description": "The stock ticker symbol, e.g. AAPL, GOOGL", - } - }, - "required": ["symbol"], - } + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The stock ticker symbol, e.g. AAPL, GOOGL", + } + }, + "required": ["symbol"], }, }, }, diff --git a/test/cli/test_serve.py b/test/cli/test_serve.py index 85081a40b..889049613 100644 --- a/test/cli/test_serve.py +++ b/test/cli/test_serve.py @@ -462,7 +462,7 @@ async def test_tool_params_passed_to_model_options(self, mock_module): function=FunctionDefinition( name="test_func", description="A test function", - parameters=FunctionParameters(RootModel={"type": "object"}), + parameters=FunctionParameters({"type": "object"}), ), ) ], @@ -471,7 +471,7 @@ async def test_tool_params_passed_to_model_options(self, mock_module): FunctionDefinition( name="legacy_func", description="A legacy function", - parameters=FunctionParameters(RootModel={"type": "object"}), + parameters=FunctionParameters({"type": "object"}), ) ], function_call="auto", diff --git a/test/cli/test_serve_integration.py b/test/cli/test_serve_integration.py index 666accbc2..35805a3d9 100644 --- a/test/cli/test_serve_integration.py +++ b/test/cli/test_serve_integration.py @@ -64,7 +64,7 @@ def as_tool_function(self) -> ToolFunction: name="get_weather", description="Get the current weather in a location", parameters=FunctionParameters( - RootModel={ + { "type": "object", "properties": { "location": {"type": "string", "description": "City name"}, @@ -571,3 +571,43 @@ def test_server_error_returns_500(self, client, mock_module): assert data["error"]["type"] == "server_error" assert "Internal error" not in data["error"]["message"] assert "Internal server error" in data["error"]["message"] + + def test_legacy_root_model_envelope_rejected_via_http(self, client, mock_module): + """Test that legacy {'RootModel': {...}} envelope is rejected at HTTP layer. + + Verifies that the FunctionParameters validator catches the legacy envelope + pattern and returns a proper 400 error via the HTTP API. + """ + # Send request with legacy envelope in function parameters + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "What's the weather?"}], + "tools": [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "RootModel": { + "type": "object", + "properties": {"location": {"type": "string"}}, + } + }, + }, + } + ], + }, + ) + + # Should return 400 with validation error + assert response.status_code == 400 + data = response.json() + assert "error" in data + assert data["error"]["type"] == "invalid_request_error" + assert ( + "Legacy {'RootModel': {...}} envelope is no longer accepted" + in data["error"]["message"] + ) diff --git a/test/cli/test_serve_models.py b/test/cli/test_serve_models.py index 5cf0e2183..f8f14ee36 100644 --- a/test/cli/test_serve_models.py +++ b/test/cli/test_serve_models.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -from cli.serve.models import StreamOptions +from cli.serve.models import FunctionParameters, JsonSchemaFormat, StreamOptions class TestStreamOptions: @@ -79,3 +79,78 @@ def test_model_dump_json_serialization(self): json_str = options.model_dump_json() assert "include_usage" in json_str assert "true" in json_str.lower() + + +class TestFunctionParameters: + """Tests for the FunctionParameters RootModel validator.""" + + def test_valid_json_schema_accepted(self): + """Test that a valid JSON Schema dict is accepted.""" + schema = { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + } + params = FunctionParameters(root=schema) + assert params.root == schema + + def test_legacy_root_model_envelope_rejected(self): + """Test that legacy {'RootModel': {...}} envelope is rejected.""" + legacy_envelope = { + "RootModel": { + "type": "object", + "properties": {"location": {"type": "string"}}, + } + } + with pytest.raises(ValidationError) as exc_info: + FunctionParameters(root=legacy_envelope) + + errors = exc_info.value.errors() + assert len(errors) == 1 + error_msg = str(exc_info.value) + assert "Legacy {'RootModel': {...}} envelope is no longer accepted" in error_msg + + def test_root_model_with_additional_keys_accepted(self): + """Test that a dict with 'RootModel' plus other keys is accepted.""" + # This is a valid schema that happens to have a property named "RootModel" + schema = { + "type": "object", + "properties": { + "RootModel": {"type": "string"}, + "other_field": {"type": "number"}, + }, + } + params = FunctionParameters(root=schema) + assert params.root == schema + + def test_empty_dict_accepted(self): + """Test that an empty dict is accepted (though not a useful schema).""" + params = FunctionParameters(root={}) + assert params.root == {} + + +class TestJsonSchemaFormat: + """Test JsonSchemaFormat serialization uses 'schema' alias, not 'schema_'.""" + + def test_serialization_uses_schema_alias(self): + """Verify schema_ serializes as 'schema' in dict and JSON output.""" + schema_def = {"type": "object", "properties": {"foo": {"type": "string"}}} + json_schema = JsonSchemaFormat(name="TestSchema", schema=schema_def) + + # Dict serialization + dumped = json_schema.model_dump() + assert "schema" in dumped and "schema_" not in dumped + assert dumped["schema"] == schema_def + + # JSON serialization + json_str = json_schema.model_dump_json() + assert '"schema":' in json_str and '"schema_":' not in json_str + + # Input accepts both 'schema' (alias) and 'schema_' (field name) + from_alias = JsonSchemaFormat(name="Test1", schema={"type": "string"}) + # Use model_validate to test runtime populate_by_name behavior (bypasses type checker) + from_field = JsonSchemaFormat.model_validate( + {"name": "Test2", "schema_": {"type": "number"}} + ) + assert from_alias.schema_ == {"type": "string"} + assert from_field.schema_ == {"type": "number"} diff --git a/test/cli/test_serve_tool_calling.py b/test/cli/test_serve_tool_calling.py index 29c5bbf1b..4f092b41b 100644 --- a/test/cli/test_serve_tool_calling.py +++ b/test/cli/test_serve_tool_calling.py @@ -66,7 +66,7 @@ def sample_tool_request(): name="get_weather", description="Get the current weather in a location", parameters=FunctionParameters( - RootModel={ + { "type": "object", "properties": { "location": { @@ -228,6 +228,66 @@ async def test_tool_choice_passed_to_model_options( assert ModelOption.TOOL_CHOICE in model_options assert model_options[ModelOption.TOOL_CHOICE] == "auto" + @pytest.mark.asyncio + async def test_standard_json_schema_tools_passed_to_model_options( + self, mock_module + ): + """Test that standard OpenAI function.parameters shape is preserved.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ChatMessage(role="user", content="What's the weather in Paris?")], + tools=[ + ToolFunction( + type="function", + function=FunctionDefinition( + name="get_weather", + description="Get the current weather in a location", + parameters=FunctionParameters( + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name", + }, + "units": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature units", + }, + }, + "required": ["location"], + } + ), + ), + ) + ], + tool_choice={"type": "function", "function": {"name": "get_weather"}}, + ) + mock_output = ModelOutputThunk("Test response") + mock_module.serve.return_value = mock_output + + endpoint = make_chat_endpoint(mock_module) + await endpoint(request) + + call_args = mock_module.serve.call_args + assert call_args is not None + model_options = call_args.kwargs["model_options"] + assert ModelOption.TOOLS in model_options + + tool_payload = model_options[ModelOption.TOOLS][0] + assert tool_payload["function"]["name"] == "get_weather" + assert tool_payload["function"]["parameters"]["type"] == "object" + assert ( + tool_payload["function"]["parameters"]["properties"]["location"]["type"] + == "string" + ) + assert tool_payload["function"]["parameters"]["required"] == ["location"] + assert model_options[ModelOption.TOOL_CHOICE] == { + "type": "function", + "function": {"name": "get_weather"}, + } + @pytest.mark.asyncio async def test_tool_calls_with_complex_arguments( self, mock_module, sample_tool_request