Skip to content

Commit f0c787f

Browse files
GWealecopybara-github
authored andcommitted
fix: preserve function call IDs for Anthropic models
Function call IDs are now preserved during session replay for Anthropic models, matching the behavior of Gemini models using the interactions API. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 905250493
1 parent c263426 commit f0c787f

3 files changed

Lines changed: 139 additions & 2 deletions

File tree

src/google/adk/flows/llm_flows/contents.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,22 @@ async def run_async(
4747
preserve_function_call_ids = False
4848
if hasattr(agent, 'canonical_model'):
4949
canonical_model = agent.canonical_model
50-
preserve_function_call_ids = (
50+
if (
5151
isinstance(canonical_model, Gemini)
5252
and canonical_model.use_interactions_api
53-
)
53+
):
54+
preserve_function_call_ids = True
55+
else:
56+
# Anthropic pairs tool_use/tool_result by id, so `adk-*` fallback
57+
# ids must survive replay.
58+
try:
59+
from ...models.anthropic_llm import AnthropicLlm
60+
except ImportError:
61+
AnthropicLlm = None
62+
if AnthropicLlm is not None and isinstance(
63+
canonical_model, AnthropicLlm
64+
):
65+
preserve_function_call_ids = True
5466

5567
# Preserve all contents that were added by instruction processor
5668
# (since llm_request.contents will be completely reassigned below)

tests/unittests/flows/llm_flows/test_contents.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,80 @@ async def test_adk_function_call_ids_preserved_for_interactions_model():
10701070
assert user_fr_part.function_response.id == function_call_id
10711071

10721072

1073+
@pytest.mark.asyncio
1074+
async def test_adk_function_call_ids_preserved_for_anthropic_model():
1075+
"""Anthropic ids must round-trip through replay so Claude can match
1076+
tool_use blocks with their tool_result blocks (issue #5074).
1077+
"""
1078+
from google.adk.models.anthropic_llm import AnthropicLlm
1079+
1080+
agent = Agent(
1081+
model=AnthropicLlm(model="claude-sonnet-4-20250514"),
1082+
name="test_agent",
1083+
)
1084+
llm_request = LlmRequest(model="claude-sonnet-4-20250514")
1085+
invocation_context = await testing_utils.create_invocation_context(
1086+
agent=agent
1087+
)
1088+
1089+
# ADK fallback ids look like ``adk-<uuid>`` and would previously be
1090+
# stripped to None for non-Gemini models on the replay path.
1091+
function_call_id = "adk-test-call-id"
1092+
events = [
1093+
Event(
1094+
invocation_id="inv1",
1095+
author="user",
1096+
content=types.UserContent("Call the tool"),
1097+
),
1098+
Event(
1099+
invocation_id="inv2",
1100+
author="test_agent",
1101+
content=types.Content(
1102+
role="model",
1103+
parts=[
1104+
types.Part(
1105+
function_call=types.FunctionCall(
1106+
id=function_call_id,
1107+
name="test_tool",
1108+
args={"x": 1},
1109+
)
1110+
)
1111+
],
1112+
),
1113+
),
1114+
Event(
1115+
invocation_id="inv3",
1116+
author="test_agent",
1117+
content=types.Content(
1118+
role="user",
1119+
parts=[
1120+
types.Part(
1121+
function_response=types.FunctionResponse(
1122+
id=function_call_id,
1123+
name="test_tool",
1124+
response={"result": 2},
1125+
)
1126+
)
1127+
],
1128+
),
1129+
),
1130+
]
1131+
invocation_context.session.events = events
1132+
1133+
async for _ in contents.request_processor.run_async(
1134+
invocation_context, llm_request
1135+
):
1136+
pass
1137+
1138+
model_fc_part = llm_request.contents[1].parts[0]
1139+
assert model_fc_part.function_call is not None
1140+
assert model_fc_part.function_call.id == function_call_id
1141+
1142+
user_fr_part = llm_request.contents[2].parts[0]
1143+
assert user_fr_part.function_response is not None
1144+
assert user_fr_part.function_response.id == function_call_id
1145+
1146+
10731147
def test_is_other_agent_reply_live_session():
10741148
"""Test _is_other_agent_reply when live_session_id is present."""
10751149
event = Event(author="another_agent", live_session_id="session_123")

tests/unittests/models/test_anthropic_llm.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,3 +1350,54 @@ async def test_non_streaming_does_not_pass_stream_param():
13501350
mock_client.messages.create.assert_called_once()
13511351
_, kwargs = mock_client.messages.create.call_args
13521352
assert "stream" not in kwargs
1353+
1354+
1355+
def test_part_to_message_block_function_call_preserves_valid_id():
1356+
"""Valid Anthropic ids must round-trip byte-for-byte."""
1357+
part = types.Part.from_function_call(name="test_tool", args={"k": "v"})
1358+
part.function_call.id = "toolu_01abc"
1359+
1360+
result = part_to_message_block(part)
1361+
1362+
assert result["id"] == "toolu_01abc"
1363+
1364+
1365+
def test_part_to_message_block_function_response_preserves_valid_id():
1366+
"""function_response ids must round-trip byte-for-byte to tool_use_id."""
1367+
part = types.Part.from_function_response(
1368+
name="test_tool", response={"result": "ok"}
1369+
)
1370+
part.function_response.id = "toolu_01abc"
1371+
1372+
result = part_to_message_block(part)
1373+
1374+
assert result["tool_use_id"] == "toolu_01abc"
1375+
1376+
1377+
def test_part_to_message_block_preserves_adk_fallback_id():
1378+
"""ADK-generated ``adk-<uuid>`` ids match Anthropic's regex and round-trip.
1379+
1380+
This is the path exercised by the contents.py fix: when Vertex Claude
1381+
returns id=None, ``populate_client_function_call_id`` writes ``adk-<uuid>``,
1382+
and contents.py preserves it through replay. ``part_to_message_block`` must
1383+
pass it through to Anthropic unchanged so call/response stay paired.
1384+
"""
1385+
call_part = types.Part.from_function_call(name="t", args={"a": 1})
1386+
call_part.function_call.id = "adk-12345678-1234-1234-1234-123456789012"
1387+
response_part = types.Part.from_function_response(
1388+
name="t", response={"result": "ok"}
1389+
)
1390+
response_part.function_response.id = (
1391+
"adk-12345678-1234-1234-1234-123456789012"
1392+
)
1393+
1394+
call_result = part_to_message_block(call_part)
1395+
response_result = part_to_message_block(response_part)
1396+
1397+
assert call_result["id"] == "adk-12345678-1234-1234-1234-123456789012"
1398+
assert (
1399+
response_result["tool_use_id"]
1400+
== "adk-12345678-1234-1234-1234-123456789012"
1401+
)
1402+
# The pair must remain matched after conversion.
1403+
assert call_result["id"] == response_result["tool_use_id"]

0 commit comments

Comments
 (0)