Skip to content

Commit 17cf823

Browse files
authored
fix: preserve Ollama streaming tool calls (#1948)
Closes #1922. Ollama can emit `tool_calls` before the final `done=true` chunk when streaming. The adapter was only reading `chunk.message.tool_calls` from the final chunk, so a real streamed tool call could turn into an empty final ADK response. I changed the streaming path to accumulate tool calls across chunks, then include them when building the final `LlmResponse`. I also moved the Ollama tool-call to ADK part conversion into one helper so streaming and non-streaming stay consistent. I verified this two ways: - Added a regression test where the tool call arrives on a non-final chunk and the done chunk has no tool calls. This failed before the fix and passes after it. - Reproduced against a live local Ollama server with `qwen3:0.6b`. Raw Ollama emitted `get_temperature({"city": "Tokyo"})` on `done=false`, then a final `done=true` chunk with no tool calls. Before the fix, kagent returned `function_calls=0`; after the fix, kagent returned the expected `get_temperature` function call. Checks run: - `uv run pytest packages/kagent-adk/tests/unittests/models/test_ollama.py -q` - `uv run pytest packages/kagent-adk/tests/unittests/models --ignore=packages/kagent-adk/tests/unittests/models/test_tls_e2e.py -q` - `uv run ruff check packages/kagent-adk/src/kagent/adk/models/_ollama.py packages/kagent-adk/tests/unittests/models/test_ollama.py` One note: the full model test directory has unrelated local TLS E2E failures because the cert fixture files are missing in my checkout. The non-TLS model tests passed. Signed-off-by: andreivince <andreivince21@gmail.com>
1 parent 32ce1fb commit 17cf823

2 files changed

Lines changed: 56 additions & 11 deletions

File tree

python/packages/kagent-adk/src/kagent/adk/models/_ollama.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def _convert_tools_to_ollama(tools: list[types.Tool]) -> list[ollama_sdk.Tool]:
133133
return ollama_tools
134134

135135

136+
def _convert_tool_call_to_part(tc: OllamaMessage.ToolCall) -> types.Part:
137+
part = types.Part.from_function_call(name=tc.function.name, args=dict(tc.function.arguments))
138+
if part.function_call:
139+
part.function_call.id = str(uuid.uuid4())
140+
return part
141+
142+
136143
class KAgentOllamaLlm(KAgentTLSMixin, BaseLlm):
137144
"""Ollama model via the native Ollama SDK.
138145
@@ -190,6 +197,7 @@ async def generate_content_async(
190197
try:
191198
if stream:
192199
aggregated_text = ""
200+
tool_calls = []
193201
response: AsyncIterator[ollama_sdk.ChatResponse] = await self._client.chat(
194202
model=llm_request.model or self.model,
195203
messages=messages,
@@ -198,6 +206,7 @@ async def generate_content_async(
198206
stream=True,
199207
)
200208
async for chunk in response:
209+
tool_calls.extend(chunk.message.tool_calls or [])
201210
if chunk.message.content:
202211
aggregated_text += chunk.message.content
203212
yield LlmResponse(
@@ -211,13 +220,7 @@ async def generate_content_async(
211220
final_parts = []
212221
if aggregated_text:
213222
final_parts.append(types.Part.from_text(text=aggregated_text))
214-
for tc in chunk.message.tool_calls or []:
215-
part = types.Part.from_function_call(
216-
name=tc.function.name, args=dict(tc.function.arguments)
217-
)
218-
if part.function_call:
219-
part.function_call.id = str(uuid.uuid4())
220-
final_parts.append(part)
223+
final_parts.extend(_convert_tool_call_to_part(tc) for tc in tool_calls)
221224
finish_reason = _done_reason_to_finish_reason(chunk.done_reason) if chunk.done_reason else None
222225
usage_metadata = None
223226
if chunk.prompt_eval_count is not None or chunk.eval_count is not None:
@@ -245,10 +248,7 @@ async def generate_content_async(
245248
if response.message.content:
246249
parts.append(types.Part.from_text(text=response.message.content))
247250
for tc in response.message.tool_calls or []:
248-
part = types.Part.from_function_call(name=tc.function.name, args=dict(tc.function.arguments))
249-
if part.function_call:
250-
part.function_call.id = str(uuid.uuid4())
251-
parts.append(part)
251+
parts.append(_convert_tool_call_to_part(tc))
252252
finish_reason = _done_reason_to_finish_reason(response.done_reason) if response.done_reason else None
253253
usage_metadata = None
254254
if response.prompt_eval_count is not None or response.eval_count is not None:

python/packages/kagent-adk/tests/unittests/models/test_ollama.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,51 @@ async def test_generate_content_forwards_ollama_options(self):
9292

9393
assert mock_client.chat.call_args.kwargs["options"] == opts
9494

95+
@pytest.mark.asyncio
96+
async def test_generate_content_streaming_accumulates_tool_calls_before_done_chunk(self):
97+
llm = KAgentOllamaLlm(model="llama3.2:latest")
98+
99+
tool_call = mock.MagicMock()
100+
tool_call.function.name = "get_weather"
101+
tool_call.function.arguments = {"city": "Tokyo"}
102+
103+
tool_chunk = mock.MagicMock()
104+
tool_chunk.message.content = ""
105+
tool_chunk.message.tool_calls = [tool_call]
106+
tool_chunk.done = False
107+
108+
done_chunk = mock.MagicMock()
109+
done_chunk.message.content = ""
110+
done_chunk.message.tool_calls = None
111+
done_chunk.done = True
112+
done_chunk.done_reason = "stop"
113+
done_chunk.prompt_eval_count = 10
114+
done_chunk.eval_count = 0
115+
116+
async def chunks():
117+
yield tool_chunk
118+
yield done_chunk
119+
120+
mock_client = mock.AsyncMock()
121+
mock_client.chat = mock.AsyncMock(return_value=chunks())
122+
123+
request = mock.MagicMock()
124+
request.model = "llama3.2:latest"
125+
request.contents = []
126+
request.config = None
127+
128+
with mock.patch.object(type(llm), "_client", new_callable=lambda: property(lambda self: mock_client)):
129+
responses = [r async for r in llm.generate_content_async(request, stream=True)]
130+
131+
assert len(responses) == 1
132+
final_response = responses[0]
133+
assert final_response.partial is False
134+
assert final_response.turn_complete is True
135+
assert len(final_response.content.parts) == 1
136+
function_call = final_response.content.parts[0].function_call
137+
assert function_call.name == "get_weather"
138+
assert dict(function_call.args) == {"city": "Tokyo"}
139+
95140

96141
class TestConvertContentToOllamaMessages:
97142
def test_image_inline_data_included(self):

0 commit comments

Comments
 (0)