Skip to content

Commit c051029

Browse files
fix(ollama): preserve tool call IDs to fix repeated same-tool calls (#3321)
1 parent 60aa718 commit c051029

2 files changed

Lines changed: 208 additions & 42 deletions

File tree

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

Lines changed: 34 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -168,6 +168,7 @@ def _convert_ollama_response_to_chatmessage(ollama_response: ChatResponse) -> Ch
168168
for ollama_tc in ollama_tool_calls:
169169
tool_calls.append(
170170
ToolCall(
171+
id=ollama_tc.get("id"),
171172
tool_name=ollama_tc["function"]["name"],
172173
arguments=ollama_tc["function"]["arguments"],
173174
)
@@ -208,6 +209,7 @@ def _build_chunk(
208209
tool_calls_list.append(
209210
ToolCallDelta(
210211
index=tool_call_index,
212+
id=tool_call.get("id"),
211213
tool_name=tool_call["function"]["name"],
212214
arguments=json.dumps(tool_call["function"]["arguments"])
213215
if tool_call["function"]["arguments"]
@@ -370,10 +372,11 @@ def _handle_streaming_response(
370372
component_info = ComponentInfo.from_component(self)
371373
chunks: list[StreamingChunk] = []
372374

373-
# Accumulators
374-
arg_by_id: dict[str, str] = {}
375-
name_by_id: dict[str, str] = {}
376-
id_order: list[str] = []
375+
# Accumulators keyed by tool_call.index (always unique per call, even for repeated tool names)
376+
arg_by_index: dict[str, str] = {}
377+
name_by_index: dict[str, str] = {}
378+
id_by_index: dict[str, str | None] = {}
379+
index_order: list[str] = []
377380
tool_call_index: int = 0
378381

379382
# track reasoning and content blocks to correctly set start=True on the first chunk of each block
@@ -400,29 +403,14 @@ def _handle_streaming_response(
400403

401404
if chunk.tool_calls:
402405
for tool_call in chunk.tool_calls:
403-
# the Ollama server doesn't guarantee an id field in every tool_calls entry.
404-
# OpenAI-compatible endpoint (/v1/chat/completions) - recent releases do add an auto-generated id
405-
# when the model produces multiple tool calls, so that clients can map results back.
406-
# Native Ollama endpoint (/api/chat) and older builds
407-
# - the JSON often contains only function.name + arguments;
408-
# many users have reported that id is missing even with several calls,
409-
# making client-side resolution harder:
410-
# https://github.com/ollama/ollama/issues/6708
411-
# https://github.com/ollama/ollama/issues/7510
412-
# - If id is provided → we can distinguish multiple calls to the same tool.
413-
414-
# - If id is missing → fallback to function.name works only when there's one call.
415-
# - That's why the deduplication logic is cautious and assumes one logical
416-
# call per name when id is absent.
417-
tool_call_id = tool_call.id or tool_call.tool_name or ""
406+
key = str(tool_call.index)
418407
args = tool_call.arguments or ""
419408

420-
# Remember first-seen order and tool name
421-
if tool_call_id not in id_order:
422-
id_order.append(tool_call_id)
423-
name_by_id[tool_call_id] = tool_call.tool_name or ""
424-
# Update the argument accumulator for this tool_call_id.
425-
arg_by_id[tool_call_id] = args
409+
if key not in index_order:
410+
index_order.append(key)
411+
name_by_index[key] = tool_call.tool_name or ""
412+
id_by_index[key] = tool_call.id
413+
arg_by_index[key] = args
426414

427415
if callback:
428416
callback(chunk)
@@ -435,9 +423,11 @@ def _handle_streaming_response(
435423
reasoning += c.reasoning.reasoning_text if c.reasoning else ""
436424

437425
tool_calls = []
438-
for tool_call_id in id_order:
439-
arguments: str = arg_by_id.get(tool_call_id, "")
440-
tool_calls.append(ToolCall(tool_name=name_by_id[tool_call_id], arguments=json.loads(arguments)))
426+
for key in index_order:
427+
arguments: str = arg_by_index.get(key, "")
428+
tool_calls.append(
429+
ToolCall(id=id_by_index[key], tool_name=name_by_index[key], arguments=json.loads(arguments))
430+
)
441431

442432
# We can't use _convert_streaming_chunks_to_chat_message because
443433
# we need to map tool_call name and args by order.
@@ -463,10 +453,11 @@ async def _handle_streaming_response_async(
463453
component_info = ComponentInfo.from_component(self)
464454
chunks: list[StreamingChunk] = []
465455

466-
# Accumulators
467-
arg_by_id: dict[str, str] = {}
468-
name_by_id: dict[str, str] = {}
469-
id_order: list[str] = []
456+
# Accumulators keyed by tool_call.index (always unique per call, even for repeated tool names)
457+
arg_by_index: dict[str, str] = {}
458+
name_by_index: dict[str, str] = {}
459+
id_by_index: dict[str, str | None] = {}
460+
index_order: list[str] = []
470461
tool_call_index: int = 0
471462

472463
# track reasoning and content blocks to correctly set start=True on the first chunk of each block
@@ -494,15 +485,14 @@ async def _handle_streaming_response_async(
494485

495486
if chunk.tool_calls:
496487
for tool_call in chunk.tool_calls:
497-
tool_call_id = tool_call.id or tool_call.tool_name or ""
488+
key = str(tool_call.index)
498489
args = tool_call.arguments or ""
499490

500-
# Remember first-seen order and tool name
501-
if tool_call_id not in id_order:
502-
id_order.append(tool_call_id)
503-
name_by_id[tool_call_id] = tool_call.tool_name or ""
504-
# Update the argument accumulator for this tool_call_id
505-
arg_by_id[tool_call_id] = args
491+
if key not in index_order:
492+
index_order.append(key)
493+
name_by_index[key] = tool_call.tool_name or ""
494+
id_by_index[key] = tool_call.id
495+
arg_by_index[key] = args
506496

507497
if callback is not None:
508498
await callback(chunk)
@@ -517,9 +507,11 @@ async def _handle_streaming_response_async(
517507
reasoning += c.reasoning.reasoning_text if c.reasoning else ""
518508

519509
tool_calls = []
520-
for tool_call_id in id_order:
521-
arguments: str = arg_by_id.get(tool_call_id, "")
522-
tool_calls.append(ToolCall(tool_name=name_by_id[tool_call_id], arguments=json.loads(arguments)))
510+
for key in index_order:
511+
arguments: str = arg_by_index.get(key, "")
512+
tool_calls.append(
513+
ToolCall(id=id_by_index[key], tool_name=name_by_index[key], arguments=json.loads(arguments))
514+
)
523515

524516
# We can't use _convert_streaming_chunks_to_chat_message because
525517
# we need to map tool_call name and args by order.

integrations/ollama/tests/test_chat_generator.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,33 @@ def test_convert_ollama_response_to_chatmessage_with_tools(self):
194194
arguments={"format": "celsius", "location": "Paris, FR"},
195195
)
196196

197+
def test_convert_ollama_response_to_chatmessage_with_repeated_tool(self):
198+
ollama_response = ChatResponse(
199+
model="some_model",
200+
created_at="2023-12-12T14:13:43.416799Z",
201+
message={
202+
"role": "assistant",
203+
"content": "",
204+
"tool_calls": [
205+
{"function": {"name": "weather", "arguments": {"city": "Paris"}}},
206+
{"function": {"name": "weather", "arguments": {"city": "London"}}},
207+
],
208+
},
209+
done=True,
210+
total_duration=5191566416,
211+
load_duration=2154458,
212+
prompt_eval_count=26,
213+
prompt_eval_duration=383809000,
214+
eval_count=298,
215+
eval_duration=4799921000,
216+
)
217+
218+
observed = _convert_ollama_response_to_chatmessage(ollama_response)
219+
220+
assert len(observed.tool_calls) == 2
221+
assert observed.tool_calls[0] == ToolCall(tool_name="weather", arguments={"city": "Paris"})
222+
assert observed.tool_calls[1] == ToolCall(tool_name="weather", arguments={"city": "London"})
223+
197224
def test_build_chunk(self):
198225
generator = OllamaChatGenerator()
199226

@@ -386,8 +413,10 @@ def test_callback(chunk: StreamingChunk):
386413
assert result["replies"][0].text is None
387414
assert result["replies"][0].tool_calls[0].tool_name == "calculator"
388415
assert result["replies"][0].tool_calls[0].arguments == {"expression": "7 * (4 + 2)"}
416+
assert result["replies"][0].tool_calls[0].id is None
389417
assert result["replies"][0].tool_calls[1].tool_name == "factorial"
390418
assert result["replies"][0].tool_calls[1].arguments == {"n": 5}
419+
assert result["replies"][0].tool_calls[1].id is None
391420
assert result["replies"][0].meta["finish_reason"] == "stop"
392421
assert result["replies"][0].meta["model"] == "qwen3:0.6b"
393422

@@ -422,6 +451,123 @@ def test_callback(chunk: StreamingChunk):
422451
assert streaming_chunks[1].tool_calls[0].to_dict() == expected
423452
assert len(streaming_chunks[2].tool_calls) == 0
424453

454+
def test_handle_streaming_response_repeated_tool_calls(self):
455+
ollama_chunks = [
456+
ChatResponse(
457+
model="qwen3:0.6b",
458+
created_at="2025-07-31T14:48:03.471292Z",
459+
done=False,
460+
message=Message(
461+
role="assistant",
462+
content="",
463+
tool_calls=[
464+
Message.ToolCall(
465+
function=Message.ToolCall.Function(name="weather", arguments={"city": "Paris"})
466+
)
467+
],
468+
),
469+
),
470+
ChatResponse(
471+
model="qwen3:0.6b",
472+
created_at="2025-07-31T14:48:03.660179Z",
473+
done=False,
474+
message=Message(
475+
role="assistant",
476+
content="",
477+
tool_calls=[
478+
Message.ToolCall(
479+
function=Message.ToolCall.Function(name="weather", arguments={"city": "London"})
480+
)
481+
],
482+
),
483+
),
484+
ChatResponse(
485+
model="qwen3:0.6b",
486+
created_at="2025-07-31T14:48:03.678729Z",
487+
done=True,
488+
done_reason="stop",
489+
total_duration=774786292,
490+
load_duration=43608375,
491+
prompt_eval_count=217,
492+
prompt_eval_duration=312974541,
493+
eval_count=46,
494+
eval_duration=417069750,
495+
message=Message(role="assistant", content=""),
496+
),
497+
]
498+
499+
generator = OllamaChatGenerator()
500+
result = generator._handle_streaming_response(ollama_chunks, None)
501+
502+
assert len(result["replies"][0].tool_calls) == 2
503+
assert result["replies"][0].tool_calls[0].tool_name == "weather"
504+
assert result["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
505+
assert result["replies"][0].tool_calls[0].id is None
506+
assert result["replies"][0].tool_calls[1].tool_name == "weather"
507+
assert result["replies"][0].tool_calls[1].arguments == {"city": "London"}
508+
assert result["replies"][0].tool_calls[1].id is None
509+
510+
@pytest.mark.asyncio
511+
async def test_handle_streaming_response_async_repeated_tool_calls(self):
512+
ollama_chunks = [
513+
ChatResponse(
514+
model="qwen3:0.6b",
515+
created_at="2025-07-31T14:48:03.471292Z",
516+
done=False,
517+
message=Message(
518+
role="assistant",
519+
content="",
520+
tool_calls=[
521+
Message.ToolCall(
522+
function=Message.ToolCall.Function(name="weather", arguments={"city": "Paris"})
523+
)
524+
],
525+
),
526+
),
527+
ChatResponse(
528+
model="qwen3:0.6b",
529+
created_at="2025-07-31T14:48:03.660179Z",
530+
done=False,
531+
message=Message(
532+
role="assistant",
533+
content="",
534+
tool_calls=[
535+
Message.ToolCall(
536+
function=Message.ToolCall.Function(name="weather", arguments={"city": "London"})
537+
)
538+
],
539+
),
540+
),
541+
ChatResponse(
542+
model="qwen3:0.6b",
543+
created_at="2025-07-31T14:48:03.678729Z",
544+
done=True,
545+
done_reason="stop",
546+
total_duration=774786292,
547+
load_duration=43608375,
548+
prompt_eval_count=217,
549+
prompt_eval_duration=312974541,
550+
eval_count=46,
551+
eval_duration=417069750,
552+
message=Message(role="assistant", content=""),
553+
),
554+
]
555+
556+
async def async_chunks():
557+
for chunk in ollama_chunks:
558+
yield chunk
559+
560+
generator = OllamaChatGenerator()
561+
result = await generator._handle_streaming_response_async(async_chunks(), None)
562+
563+
assert len(result["replies"][0].tool_calls) == 2
564+
assert result["replies"][0].tool_calls[0].tool_name == "weather"
565+
assert result["replies"][0].tool_calls[0].arguments == {"city": "Paris"}
566+
assert result["replies"][0].tool_calls[0].id is None
567+
assert result["replies"][0].tool_calls[1].tool_name == "weather"
568+
assert result["replies"][0].tool_calls[1].arguments == {"city": "London"}
569+
assert result["replies"][0].tool_calls[1].id is None
570+
425571
def test_handle_streaming_response_tool_calls_with_thinking(self):
426572
ollama_chunks = [
427573
ChatResponse(
@@ -536,6 +682,7 @@ def test_callback(chunk: StreamingChunk):
536682
assert result["replies"][0].text is None
537683
assert result["replies"][0].tool_calls[0].tool_name == "add_two_numbers"
538684
assert result["replies"][0].tool_calls[0].arguments == {"a": 2, "b": 2}
685+
assert result["replies"][0].tool_calls[0].id is None
539686
assert result["replies"][0].reasoning.reasoning_text == "Okay, the user is asking 2 plus 2."
540687
assert result["replies"][0].meta["finish_reason"] == "stop"
541688
assert result["replies"][0].meta["model"] == "qwen3:0.6b"
@@ -1306,6 +1453,33 @@ def multiply(a: int, b: int) -> int:
13061453
assert new_response.tool_calls[0].tool_name == "multiply"
13071454
assert new_response.tool_calls[0].arguments == {"a": 5, "b": 10}
13081455

1456+
@pytest.mark.parametrize("streaming_callback", [None, print_streaming_chunk])
1457+
def test_live_run_with_repeated_tool_calls(self, tools, streaming_callback):
1458+
component = OllamaChatGenerator(model="qwen3:0.6b", tools=tools, streaming_callback=streaming_callback)
1459+
tool_invoker = ToolInvoker(tools=tools)
1460+
1461+
messages = [ChatMessage.from_user("What is the weather in Paris and London?")]
1462+
response = component.run(messages)
1463+
1464+
assert len(response["replies"]) == 1
1465+
assistant_msg = response["replies"][0]
1466+
1467+
assert assistant_msg.tool_calls
1468+
assert len(assistant_msg.tool_calls) == 2
1469+
for tc in assistant_msg.tool_calls:
1470+
assert isinstance(tc, ToolCall)
1471+
assert tc.tool_name == "weather"
1472+
assert "city" in tc.arguments
1473+
1474+
cities = {tc.arguments["city"].lower() for tc in assistant_msg.tool_calls}
1475+
assert any("paris" in c for c in cities)
1476+
assert any("london" in c for c in cities)
1477+
1478+
tool_messages = tool_invoker.run(messages=[assistant_msg])["tool_messages"]
1479+
final_response = component.run([*messages, assistant_msg, *tool_messages])
1480+
assert len(final_response["replies"]) == 1
1481+
assert final_response["replies"][0].text
1482+
13091483
def test_live_run_with_tools_and_format(self, tools):
13101484
response_format = {
13111485
"type": "object",

0 commit comments

Comments
 (0)