Skip to content

Commit 64c3451

Browse files
committed
test(cli): update tests and example removing RootModel
Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent e75dda3 commit 64c3451

5 files changed

Lines changed: 106 additions & 52 deletions

File tree

cli/serve/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Any, Literal
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, RootModel, model_validator
44

55
from mellea.helpers.openai_compatible_helpers import CompletionUsage
66

docs/examples/m_serve/client_streaming_tool_calling.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,19 @@
2828
"name": "get_weather",
2929
"description": "Get the current weather in a given location",
3030
"parameters": {
31-
"RootModel": {
32-
"type": "object",
33-
"properties": {
34-
"location": {
35-
"type": "string",
36-
"description": "The city name, e.g. San Francisco",
37-
},
38-
"units": {
39-
"type": "string",
40-
"enum": ["celsius", "fahrenheit"],
41-
"description": "Temperature units",
42-
},
31+
"type": "object",
32+
"properties": {
33+
"location": {
34+
"type": "string",
35+
"description": "The city name, e.g. San Francisco",
4336
},
44-
"required": ["location"],
45-
}
37+
"units": {
38+
"type": "string",
39+
"enum": ["celsius", "fahrenheit"],
40+
"description": "Temperature units",
41+
},
42+
},
43+
"required": ["location"],
4644
},
4745
},
4846
},
@@ -52,16 +50,14 @@
5250
"name": "get_stock_price",
5351
"description": "Get the current stock price for a given ticker symbol",
5452
"parameters": {
55-
"RootModel": {
56-
"type": "object",
57-
"properties": {
58-
"symbol": {
59-
"type": "string",
60-
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
61-
}
62-
},
63-
"required": ["symbol"],
64-
}
53+
"type": "object",
54+
"properties": {
55+
"symbol": {
56+
"type": "string",
57+
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
58+
}
59+
},
60+
"required": ["symbol"],
6561
},
6662
},
6763
},
@@ -232,7 +228,9 @@ def main():
232228
# Example 4: Multi-turn conversation with tool use
233229
print("\n\n4. Multi-turn Conversation (Streaming)")
234230
print("-" * 60)
235-
messages = [{"role": "user", "content": "What's the weather in Paris?"}]
231+
messages: list[dict[str, Any]] = [
232+
{"role": "user", "content": "What's the weather in Paris?"}
233+
]
236234

237235
print(f"User: {messages[0]['content']}")
238236
print("\nAssistant: ", end="", flush=True)

docs/examples/m_serve/client_tool_calling.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,19 @@
2727
"name": "get_weather",
2828
"description": "Get the current weather in a given location",
2929
"parameters": {
30-
"RootModel": {
31-
"type": "object",
32-
"properties": {
33-
"location": {
34-
"type": "string",
35-
"description": "The city name, e.g. San Francisco",
36-
},
37-
"units": {
38-
"type": "string",
39-
"enum": ["celsius", "fahrenheit"],
40-
"description": "Temperature units",
41-
},
30+
"type": "object",
31+
"properties": {
32+
"location": {
33+
"type": "string",
34+
"description": "The city name, e.g. San Francisco",
4235
},
43-
"required": ["location"],
44-
}
36+
"units": {
37+
"type": "string",
38+
"enum": ["celsius", "fahrenheit"],
39+
"description": "Temperature units",
40+
},
41+
},
42+
"required": ["location"],
4543
},
4644
},
4745
},
@@ -51,16 +49,14 @@
5149
"name": "get_stock_price",
5250
"description": "Get the current stock price for a given ticker symbol",
5351
"parameters": {
54-
"RootModel": {
55-
"type": "object",
56-
"properties": {
57-
"symbol": {
58-
"type": "string",
59-
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
60-
}
61-
},
62-
"required": ["symbol"],
63-
}
52+
"type": "object",
53+
"properties": {
54+
"symbol": {
55+
"type": "string",
56+
"description": "The stock ticker symbol, e.g. AAPL, GOOGL",
57+
}
58+
},
59+
"required": ["symbol"],
6460
},
6561
},
6662
},

test/cli/test_serve_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def as_tool_function(self) -> ToolFunction:
6464
name="get_weather",
6565
description="Get the current weather in a location",
6666
parameters=FunctionParameters(
67-
RootModel={
67+
{
6868
"type": "object",
6969
"properties": {
7070
"location": {"type": "string", "description": "City name"},

test/cli/test_serve_tool_calling.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def sample_tool_request():
6666
name="get_weather",
6767
description="Get the current weather in a location",
6868
parameters=FunctionParameters(
69-
RootModel={
69+
{
7070
"type": "object",
7171
"properties": {
7272
"location": {
@@ -228,6 +228,66 @@ async def test_tool_choice_passed_to_model_options(
228228
assert ModelOption.TOOL_CHOICE in model_options
229229
assert model_options[ModelOption.TOOL_CHOICE] == "auto"
230230

231+
@pytest.mark.asyncio
232+
async def test_standard_json_schema_tools_passed_to_model_options(
233+
self, mock_module
234+
):
235+
"""Test that standard OpenAI function.parameters shape is preserved."""
236+
request = ChatCompletionRequest(
237+
model="test-model",
238+
messages=[ChatMessage(role="user", content="What's the weather in Paris?")],
239+
tools=[
240+
ToolFunction(
241+
type="function",
242+
function=FunctionDefinition(
243+
name="get_weather",
244+
description="Get the current weather in a location",
245+
parameters=FunctionParameters(
246+
{
247+
"type": "object",
248+
"properties": {
249+
"location": {
250+
"type": "string",
251+
"description": "The city name",
252+
},
253+
"units": {
254+
"type": "string",
255+
"enum": ["celsius", "fahrenheit"],
256+
"description": "Temperature units",
257+
},
258+
},
259+
"required": ["location"],
260+
}
261+
),
262+
),
263+
)
264+
],
265+
tool_choice={"type": "function", "function": {"name": "get_weather"}},
266+
)
267+
mock_output = ModelOutputThunk("Test response")
268+
mock_module.serve.return_value = mock_output
269+
270+
endpoint = make_chat_endpoint(mock_module)
271+
await endpoint(request)
272+
273+
call_args = mock_module.serve.call_args
274+
assert call_args is not None
275+
model_options = call_args.kwargs["model_options"]
276+
assert ModelOption.TOOLS in model_options
277+
278+
tool_payload = model_options[ModelOption.TOOLS][0]
279+
assert tool_payload["function"]["name"] == "get_weather"
280+
assert tool_payload["function"]["parameters"]["type"] == "object"
281+
assert (
282+
tool_payload["function"]["parameters"]["properties"]["location"]["type"]
283+
== "string"
284+
)
285+
assert tool_payload["function"]["parameters"]["required"] == ["location"]
286+
assert model_options[ModelOption.TOOL_CHOICE] == {
287+
"type": "function",
288+
"function": {"name": "get_weather"},
289+
}
290+
231291
@pytest.mark.asyncio
232292
async def test_tool_calls_with_complex_arguments(
233293
self, mock_module, sample_tool_request

0 commit comments

Comments
 (0)