Skip to content

Commit ca18818

Browse files
authored
feat: add tool calling support to m serve (#850)
* feat: add tool calling support to m serve Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: fixed the bug in m serve where finish_reason=tool_calls for empty dict Fixed the bug where an empty tool_calls dict ({}) incorrectly produced finish_reason="tool_calls" with an empty array instead of finish_reason="stop" with tool_calls=None. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: move message add to outside the loop in client_tool_calling.py example Issue: The assistant message was being added inside the loop for each tool call, causing duplication when multiple tool calls were present. Fix: Moved the assistant message append outside the loop (before processing tool calls), so it's only added once. Now the loop only adds tool responses. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: cli app.py loop variable tool_name is never used The dict key tool_name is never used — the function name comes from model_tool_call.name. Using .values() instead. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: fix test_mot_init_typing() hasattr was always true Replaced hasattr() with direct __dict__ membership tests to correctly distinguish: 1. Typed instances (ModelOutputThunk[float](...)) - have __orig_class__ in their instance dict 2. Untyped instances (ModelOutputThunk(...)) - do NOT have __orig_class__ in their instance dict Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: update m_serve_example_tool_calling.py to use safer example tool Security issue resolved in `m_serve_example_tool_calling.py`: **Changes made:** - Replaced `CalculatorTool` (which used unsafe `eval()` with `# noqa: S307`) with `GetStockPriceTool` - New tool demonstrates API-calling pattern with mock stock prices (AAPL, GOOGL, MSFT, TSLA) - Updated all references: `calculator_tool` → `stock_price_tool` - Maintains the same tool calling demonstration with two tools (weather + stock price) **Why this is better:** - Eliminates security risk entirely (no `eval()` or suppressed lints) - Still demonstrates multiple tools effectively - Uses safe, realistic API-calling pattern that users can copy - No dangerous code that could be copy-pasted into production Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: replace repeated hard-coded string with constant Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: add TOOL_CHOICE to ModelOptions like TEMPERATURE not a sentinel The pass-thru behavior was not clear enough, so adding it to ModelOptions where important options are known. Most of these are sentinels which are removed (because @@@) but this will be like TEMPERATURE which is passed through to the backends. No behavior change, but give a handly constant and a place to look for these. This does not address all the other possible pass through args. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> Assisted-by: IBM Bob * fix: fix m serve tool-calling examples - switch server example to OpenAIBackend - align tool-calling example with tested Granite model setup - narrow advertised tools when `tool_choice` selects a specific function - enable `tool_calls=True` in the serve path - replace calculator example with stock-price tool - examples 1/2 as tool-call-only demos - example 4 as the full tool execution round-trip - improve client diagnostics for empty/no-tool responses Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> Assisted-by: IBM Bob * fix: remove unused imports in example Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * feat: cli support for OpenAI API tool calling with streaming Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> Assisted-by: IBM Bob * fix: add required index field to streaming tool call deltas The OpenAI streaming spec requires each item in delta.tool_calls to carry an index field. Clients including the openai Python SDK, LangChain, and LiteLLM key their delta-reassembly state machine on this field. Without it, they silently drop tool calls, coalesce them incorrectly, or raise a TypeError depending on version. Changes: - Add ChatCompletionMessageToolCallDelta model with required index field - Add ToolCallFunctionDelta model for streaming function deltas - Update ChatCompletionChunkDelta to use delta models - Update streaming.py to populate index field using enumerate() - Add comprehensive tests verifying index field presence - Update existing test to check for index field The bundled client_streaming_tool_calling.py example masked this issue because it reads delta.tool_calls verbatim rather than going through SDK delta reassembly. Fixes compatibility with OpenAI SDK, LangChain, and LiteLLM streaming tool call consumers. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> Assisted-by: IBM Bob * fix: move build_tool_calls invocation build_tool_calls was called before streaming block and then not used in case of streaming. Rearrange condition and call to avoid wasted call. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * test: add integration test for cli/serve using TestClient with streaming and tool calling Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * fix: use fallback for json.dumps in build_tool_calls Use str to non-serializable types. This should effectively avoid TypeError (in normal situations). Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * test: restore cli streaming tests to fix conflicts New tests for tooling improved coverage, but the significant rewrite caused too much diverging from main. Keeping the old tests in places while adding new tests in new file will help sort this out. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * test: update output.usage -> output.generation.usage Rebased and now the new tests need updating. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * test: update tests usage -> gneration.usage More new tests need fixing after rebase. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> * refactor(serve): simplify tool call construction with Pydantic validation - Use model_validate() instead of manual field mapping for tool calls - Move uuid import to module level in openai_compatible_helpers - Replace manual async function with AsyncMock in streaming error test - Remove redundant comments about tool call extraction These changes reduce code duplication and leverage Pydantic's built-in validation for cleaner, more maintainable code. Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> Assisted-by: IBM Bob * fix: remove unused imports Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com> --------- Signed-off-by: Mark Sturdevant <mark.sturdevant@ibm.com>
1 parent bf8a8ad commit ca18818

15 files changed

Lines changed: 2767 additions & 21 deletions

cli/serve/app.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import time
99
import uuid
10+
from typing import Literal
1011

1112
try:
1213
import typer
@@ -21,11 +22,15 @@
2122
) from e
2223

2324
from mellea.backends.model_options import ModelOption
24-
from mellea.helpers.openai_compatible_helpers import build_completion_usage
25+
from mellea.helpers.openai_compatible_helpers import (
26+
build_completion_usage,
27+
build_tool_calls,
28+
)
2529

2630
from .models import (
2731
ChatCompletion,
2832
ChatCompletionMessage,
33+
ChatCompletionMessageToolCall,
2934
ChatCompletionRequest,
3035
Choice,
3136
OpenAIError,
@@ -111,14 +116,14 @@ def _build_model_options(request: ChatCompletionRequest) -> dict:
111116
"response_format", # Response format (json_object) - not yet implemented
112117
"functions", # Legacy function calling - not yet implemented
113118
"function_call", # Legacy function calling - not yet implemented
114-
"tools", # Tool calling - not yet implemented
115-
"tool_choice", # Tool choice - not yet implemented
116119
}
117120
openai_to_model_option = {
118121
"temperature": ModelOption.TEMPERATURE,
119122
"max_tokens": ModelOption.MAX_NEW_TOKENS,
120123
"seed": ModelOption.SEED,
121124
"stream": ModelOption.STREAM,
125+
"tools": ModelOption.TOOLS,
126+
"tool_choice": ModelOption.TOOL_CHOICE,
122127
}
123128

124129
# Get all non-None fields
@@ -171,8 +176,6 @@ async def endpoint(request: ChatCompletionRequest):
171176
model_options=model_options,
172177
)
173178

174-
# system_fingerprint represents backend config hash, not model name
175-
# The model name is already in response.model (line 73)
176179
# Leave as None since we don't track backend config fingerprints yet
177180
system_fingerprint = None
178181

@@ -190,6 +193,24 @@ async def endpoint(request: ChatCompletionRequest):
190193
media_type="text/event-stream",
191194
)
192195

196+
tool_calls_list = build_tool_calls(output)
197+
tool_calls = (
198+
[
199+
ChatCompletionMessageToolCall.model_validate(tc)
200+
for tc in tool_calls_list
201+
]
202+
if tool_calls_list
203+
else None
204+
)
205+
206+
# Determine finish_reason based on tool calls
207+
finish_reason: (
208+
Literal[
209+
"stop", "length", "content_filter", "tool_calls", "function_call"
210+
]
211+
| None
212+
) = "tool_calls" if tool_calls else "stop"
213+
193214
return ChatCompletion(
194215
id=completion_id,
195216
model=request.model,
@@ -198,9 +219,11 @@ async def endpoint(request: ChatCompletionRequest):
198219
Choice(
199220
index=0,
200221
message=ChatCompletionMessage(
201-
content=output.value, role="assistant"
222+
content=output.value,
223+
role="assistant",
224+
tool_calls=tool_calls,
202225
),
203-
finish_reason="stop",
226+
finish_reason=finish_reason,
204227
)
205228
],
206229
object="chat.completion", # type: ignore

cli/serve/models.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,67 @@ class ChatCompletionRequest(BaseModel):
8080
extra: dict[str, Any] = Field(default_factory=dict)
8181

8282

83+
class ToolCallFunction(BaseModel):
84+
"""Function details for a tool call."""
85+
86+
name: str
87+
"""The name of the function to call."""
88+
89+
arguments: str
90+
"""The arguments to call the function with, as a JSON string."""
91+
92+
93+
class ChatCompletionMessageToolCall(BaseModel):
94+
"""A tool call generated by the model (non-streaming)."""
95+
96+
id: str
97+
"""The ID of the tool call."""
98+
99+
type: Literal["function"]
100+
"""The type of the tool. Currently, only 'function' is supported."""
101+
102+
function: ToolCallFunction
103+
"""The function that the model called."""
104+
105+
106+
class ToolCallFunctionDelta(BaseModel):
107+
"""Function details for a streaming tool call delta.
108+
109+
In streaming responses, function name and arguments may arrive across
110+
multiple chunks, so both fields are optional.
111+
"""
112+
113+
name: str | None = None
114+
"""The name of the function to call (may be None in delta chunks)."""
115+
116+
arguments: str | None = None
117+
"""The arguments fragment for this delta (may be None in delta chunks)."""
118+
119+
120+
class ChatCompletionMessageToolCallDelta(BaseModel):
121+
"""A tool call delta in a streaming response.
122+
123+
Per OpenAI streaming spec, each delta must include an index field that
124+
clients use to reassemble tool calls across chunks. The id, type, and
125+
function fields are optional since they may arrive incrementally.
126+
"""
127+
128+
index: int
129+
"""The index of this tool call in the tool_calls array.
130+
131+
Required for delta reassembly in OpenAI SDK and compatible clients.
132+
"""
133+
134+
id: str | None = None
135+
"""The ID of the tool call (may be None in subsequent delta chunks)."""
136+
137+
type: Literal["function"] | None = None
138+
"""The type of the tool (may be None in subsequent delta chunks)."""
139+
140+
function: ToolCallFunctionDelta | None = None
141+
"""The function delta for this chunk (may be None in some chunks)."""
142+
143+
83144
# Taking this from OpenAI types https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py,
84145
class ChatCompletionMessage(BaseModel):
85146
content: str | None = None
@@ -91,6 +152,9 @@ class ChatCompletionMessage(BaseModel):
91152
role: Literal["assistant"]
92153
"""The role of the author of this message."""
93154

155+
tool_calls: list[ChatCompletionMessageToolCall] | None = None
156+
"""The tool calls generated by the model, such as function calls."""
157+
94158

95159
class Choice(BaseModel):
96160
index: int
@@ -144,6 +208,14 @@ class ChatCompletionChunkDelta(BaseModel):
144208
refusal: str | None = None
145209
"""The refusal message fragment, if any."""
146210

211+
tool_calls: list[ChatCompletionMessageToolCallDelta] | None = None
212+
"""The tool call deltas in this chunk.
213+
214+
Each delta includes a required index field for reassembly by OpenAI SDK
215+
and compatible clients. The id, type, and function fields are optional
216+
since they may arrive incrementally across multiple chunks.
217+
"""
218+
147219

148220
class ChatCompletionChunkChoice(BaseModel):
149221
"""A choice in a streaming chunk."""

cli/serve/streaming.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""Streaming utilities for OpenAI-compatible server responses."""
22

33
from collections.abc import AsyncGenerator
4+
from typing import Literal
45

56
from mellea.core.base import ModelOutputThunk
67
from mellea.core.utils import MelleaLogger
7-
from mellea.helpers.openai_compatible_helpers import build_completion_usage
8+
from mellea.helpers.openai_compatible_helpers import (
9+
build_completion_usage,
10+
build_tool_calls,
11+
)
812

913
from .models import (
1014
ChatCompletionChunk,
1115
ChatCompletionChunkChoice,
1216
ChatCompletionChunkDelta,
17+
ChatCompletionMessageToolCallDelta,
1318
OpenAIError,
1419
OpenAIErrorResponse,
1520
StreamOptions,
@@ -98,6 +103,38 @@ async def stream_chat_completion_chunks(
98103
)
99104
yield f"data: {chunk.model_dump_json()}\n\n"
100105

106+
tool_calls_list = build_tool_calls(output)
107+
108+
if tool_calls_list:
109+
# Convert to ChatCompletionMessageToolCallDelta objects with required index
110+
tool_calls = [
111+
ChatCompletionMessageToolCallDelta.model_validate({**tc, "index": idx})
112+
for idx, tc in enumerate(tool_calls_list)
113+
]
114+
115+
# Emit tool calls in a separate chunk before the final chunk
116+
tool_call_chunk = ChatCompletionChunk(
117+
id=completion_id,
118+
model=model,
119+
created=created,
120+
choices=[
121+
ChatCompletionChunkChoice(
122+
index=0,
123+
delta=ChatCompletionChunkDelta(tool_calls=tool_calls),
124+
finish_reason=None,
125+
)
126+
],
127+
object="chat.completion.chunk",
128+
system_fingerprint=system_fingerprint,
129+
)
130+
yield f"data: {tool_call_chunk.model_dump_json()}\n\n"
131+
132+
# Determine finish_reason based on tool calls
133+
finish_reason: (
134+
Literal["stop", "length", "content_filter", "tool_calls", "function_call"]
135+
| None
136+
) = "tool_calls" if tool_calls_list else "stop"
137+
101138
# Include usage in final chunk only if explicitly requested via stream_options
102139
# Per OpenAI spec: usage is only included when stream_options.include_usage=True
103140
include_usage = stream_options is not None and stream_options.include_usage
@@ -112,7 +149,7 @@ async def stream_chat_completion_chunks(
112149
ChatCompletionChunkChoice(
113150
index=0,
114151
delta=ChatCompletionChunkDelta(content=None),
115-
finish_reason="stop",
152+
finish_reason=finish_reason,
116153
)
117154
],
118155
object="chat.completion.chunk",

0 commit comments

Comments
 (0)