Skip to content

Commit 5c2503e

Browse files
committed
fix: auto-heal orphaned function_calls to prevent crash loop
When execution is interrupted (e.g., by a server restart or connection loss) after a function_call has been saved but before its function_response is saved, the session becomes unrecoverable, because Anthropic/OpenAI require tool_calls to be immediately followed by tool_results.

This change detects orphaned function_calls and injects synthetic error responses so that the session recovers gracefully.

Changes:
- Add _ORPHANED_CALL_ERROR_RESPONSE constant for error responses
- Add _create_synthetic_response_for_orphaned_calls helper function
- Detect orphaned calls in _rearrange_events_for_async_function_responses_in_history
- Use 'user' as author for synthetic function response events
- Add 5 comprehensive test cases for auto-healing behavior

Fixes #3971
1 parent 82fa10b commit 5c2503e

File tree

2 files changed

+334
-11
lines changed

2 files changed

+334
-11
lines changed

src/google/adk/flows/llm_flows/contents.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
logger = logging.getLogger('google_adk.' + __name__)
3535

36+
# Error response for orphaned function calls (issue #3971)
37+
# Shared by every synthetic function_response built for an orphaned call;
# treat this dict as read-only.
_ORPHANED_CALL_ERROR_RESPONSE = {'error': 'Tool execution was interrupted.'}
38+
3639

3740
class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
3841
"""Builds the contents for the LLM request."""
@@ -77,10 +80,44 @@ async def run_async(
7780
request_processor = _ContentLlmRequestProcessor()
7881

7982

83+
def _create_synthetic_response_for_orphaned_calls(
    event: Event,
    orphaned_calls: list[types.FunctionCall],
) -> Event:
  """Builds one event holding an error response for each orphaned call.

  A warning is logged for every call that is healed.

  Args:
    event: The event that carries the orphaned function calls. Its
      `invocation_id` and `branch` are copied onto the returned event.
    orphaned_calls: The function calls that have no matching
      function_response.

  Returns:
    A new event with author 'user' and content role 'user', containing one
    function_response part per orphaned call, in the order given. Each part
    reuses the call's name and id and carries the orphaned-call error
    payload.
  """
  parts: list[types.Part] = []

  for func_call in orphaned_calls:
    logger.warning(
        'Auto-healing orphaned function_call (id=%s, name=%s). '
        'This indicates execution was interrupted before tool completion.',
        func_call.id,
        func_call.name,
    )
    part = types.Part.from_function_response(
        name=func_call.name,
        # Give each part its own dict so that no response aliases the
        # module-level constant (or another part's payload).
        response=dict(_ORPHANED_CALL_ERROR_RESPONSE),
    )
    # from_function_response is only given name/response here, so the id
    # that ties the response to its call is set afterwards.
    part.function_response.id = func_call.id
    parts.append(part)

  return Event(
      invocation_id=event.invocation_id,
      author='user',
      content=types.Content(role='user', parts=parts),
      branch=event.branch,
  )
111+
112+
80113
def _rearrange_events_for_async_function_responses_in_history(
81114
events: list[Event],
82115
) -> list[Event]:
83-
"""Rearrange the async function_response events in the history."""
116+
"""Rearrange the async function_response events in the history.
117+
118+
Also auto-heals orphaned function_calls by injecting synthetic error
119+
responses to prevent crash loops (issue #3971).
120+
"""
84121

85122
function_call_id_to_response_events_index: dict[str, int] = {}
86123
for i, event in enumerate(events):
@@ -96,26 +133,34 @@ def _rearrange_events_for_async_function_responses_in_history(
96133
# function_response should be handled together with function_call below.
97134
continue
98135
elif event.get_function_calls():
99-
100136
function_response_events_indices = set()
137+
orphaned_calls: list[types.FunctionCall] = []
101138
for function_call in event.get_function_calls():
102139
function_call_id = function_call.id
103140
if function_call_id in function_call_id_to_response_events_index:
104141
function_response_events_indices.add(
105142
function_call_id_to_response_events_index[function_call_id]
106143
)
144+
elif function_call_id:
145+
orphaned_calls.append(function_call)
107146
result_events.append(event)
108-
if not function_response_events_indices:
147+
if not function_response_events_indices and not orphaned_calls:
109148
continue
110-
if len(function_response_events_indices) == 1:
111-
result_events.append(
112-
events[next(iter(function_response_events_indices))]
113-
)
114-
else: # Merge all async function_response as one response event
149+
if function_response_events_indices:
150+
if len(function_response_events_indices) == 1:
151+
result_events.append(
152+
events[next(iter(function_response_events_indices))]
153+
)
154+
else: # Merge all async function_response as one response event
155+
result_events.append(
156+
_merge_function_response_events(
157+
[events[i] for i in sorted(function_response_events_indices)]
158+
)
159+
)
160+
# Inject synthetic responses for orphaned calls (issue #3971)
161+
if orphaned_calls:
115162
result_events.append(
116-
_merge_function_response_events(
117-
[events[i] for i in sorted(function_response_events_indices)]
118-
)
163+
_create_synthetic_response_for_orphaned_calls(event, orphaned_calls)
119164
)
120165
continue
121166
else:
@@ -444,6 +489,7 @@ def _get_contents(
444489
result_events = _rearrange_events_for_latest_function_response(
445490
filtered_events
446491
)
492+
# Auto-heal orphaned function_calls to prevent crash loop (issue #3971)
447493
result_events = _rearrange_events_for_async_function_responses_in_history(
448494
result_events
449495
)

tests/unittests/flows/llm_flows/test_contents_function.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Tests for function call/response rearrangement in contents module."""
1616

17+
import logging
18+
1719
from google.adk.agents.llm_agent import Agent
1820
from google.adk.events.event import Event
1921
from google.adk.flows.llm_flows import contents
@@ -590,3 +592,278 @@ async def test_error_when_function_response_without_matching_call():
590592
invocation_context, llm_request
591593
):
592594
pass
595+
596+
597+
@pytest.mark.asyncio
async def test_auto_healing_single_orphaned_function_call():
  """A lone function_call with no response gets a synthetic error response.

  The history ends with a model event that carries a function_call and has
  no function_response after it, as happens when execution is interrupted
  before the response is saved.

  Checks that:
  - the request processor finishes without raising,
  - exactly one extra "user" content follows the function_call,
  - that content holds a single function_response whose id and name match
    the call and whose payload is the orphaned-call error response.
  """
  model_name = "gemini-2.5-flash"
  agent = Agent(model=model_name, name="test_agent")
  llm_request = LlmRequest(model=model_name)
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent
  )

  user_event = Event(
      invocation_id="inv1",
      author="user",
      content=types.UserContent("What is the weather in Seoul?"),
  )
  call_event = Event(
      invocation_id="inv2",
      author="test_agent",
      content=types.ModelContent([
          types.Part(
              function_call=types.FunctionCall(
                  id="orphaned_123",
                  name="get_weather",
                  args={"location": "Seoul"},
              )
          )
      ]),
  )
  # Deliberately no function_response event: execution was interrupted.
  invocation_context.session.events = [user_event, call_event]

  # Must run to completion without raising.
  async for _ in contents.request_processor.run_async(
      invocation_context, llm_request
  ):
    pass

  # user message + function_call + injected synthetic response
  assert len(llm_request.contents) == 3

  healed_content = llm_request.contents[2]
  assert healed_content.role == "user"
  assert len(healed_content.parts) == 1

  healed_response = healed_content.parts[0].function_response
  assert healed_response.id == "orphaned_123"
  assert healed_response.name == "get_weather"
  assert healed_response.response == contents._ORPHANED_CALL_ERROR_RESPONSE
652+
653+
654+
@pytest.mark.asyncio
async def test_auto_healing_multiple_orphaned_function_calls():
  """Two orphaned calls in one event are healed by a single content.

  Both function_calls live in the same model event and neither has a
  response; the injected content must carry one function_response per call.
  """
  model_name = "gemini-2.5-flash"
  agent = Agent(model=model_name, name="test_agent")
  llm_request = LlmRequest(model=model_name)
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent
  )

  call_parts = [
      types.Part(
          function_call=types.FunctionCall(
              id="orphaned_1", name="tool_a", args={"arg": "value1"}
          )
      ),
      types.Part(
          function_call=types.FunctionCall(
              id="orphaned_2", name="tool_b", args={"arg": "value2"}
          )
      ),
  ]
  # Deliberately no function_response events: execution was interrupted.
  invocation_context.session.events = [
      Event(
          invocation_id="inv1",
          author="user",
          content=types.UserContent("Run multiple tools"),
      ),
      Event(
          invocation_id="inv2",
          author="test_agent",
          content=types.ModelContent(call_parts),
      ),
  ]

  # Must run to completion without raising.
  async for _ in contents.request_processor.run_async(
      invocation_context, llm_request
  ):
    pass

  # user message + function_calls + one content with both synthetic responses
  assert len(llm_request.contents) == 3

  healed_content = llm_request.contents[2]
  assert healed_content.role == "user"
  assert len(healed_content.parts) == 2

  healed_ids = {part.function_response.id for part in healed_content.parts}
  assert healed_ids == {"orphaned_1", "orphaned_2"}
703+
704+
705+
@pytest.mark.asyncio
async def test_auto_healing_partial_orphaned_function_calls():
  """Only the call lacking a response is healed.

  One model event issues two function_calls; a later event answers just one
  of them. The real response must be kept, and a synthetic error response
  must be added for the unanswered call only.
  """
  model_name = "gemini-2.5-flash"
  agent = Agent(model=model_name, name="test_agent")
  llm_request = LlmRequest(model=model_name)
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent
  )

  answered_call = types.FunctionCall(
      id="completed_123", name="tool_complete", args={}
  )
  unanswered_call = types.FunctionCall(
      id="orphaned_456", name="tool_orphaned", args={}
  )
  real_response = types.FunctionResponse(
      id="completed_123",
      name="tool_complete",
      response={"result": "success"},
  )

  invocation_context.session.events = [
      Event(
          invocation_id="inv1",
          author="user",
          content=types.UserContent("Run two tools"),
      ),
      Event(
          invocation_id="inv2",
          author="test_agent",
          content=types.ModelContent([
              types.Part(function_call=answered_call),
              types.Part(function_call=unanswered_call),
          ]),
      ),
      # Only the first call ever received a response.
      Event(
          invocation_id="inv3",
          author="user",
          content=types.UserContent(
              [types.Part(function_response=real_response)]
          ),
      ),
  ]

  async for _ in contents.request_processor.run_async(
      invocation_context, llm_request
  ):
    pass

  # user message + function_calls + real response + synthetic response
  assert len(llm_request.contents) == 4

  # The real response comes first ...
  real_content = llm_request.contents[2]
  assert real_content.parts[0].function_response.id == "completed_123"

  # ... followed by the synthetic one for the unanswered call.
  healed_response = llm_request.contents[3].parts[0].function_response
  assert healed_response.id == "orphaned_456"
  assert healed_response.response == contents._ORPHANED_CALL_ERROR_RESPONSE
775+
776+
777+
@pytest.mark.asyncio
async def test_auto_healing_no_healing_when_responses_exist():
  """A fully answered function_call gets no synthetic response."""
  model_name = "gemini-2.5-flash"
  agent = Agent(model=model_name, name="test_agent")
  llm_request = LlmRequest(model=model_name)
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent
  )

  call_part = types.Part(
      function_call=types.FunctionCall(
          id="complete_call", name="search_tool", args={"query": "test"}
      )
  )
  response_part = types.Part(
      function_response=types.FunctionResponse(
          id="complete_call",
          name="search_tool",
          response={"results": ["item1"]},
      )
  )

  invocation_context.session.events = [
      Event(
          invocation_id="inv1",
          author="user",
          content=types.UserContent("Search for test"),
      ),
      Event(
          invocation_id="inv2",
          author="test_agent",
          content=types.ModelContent([call_part]),
      ),
      Event(
          invocation_id="inv3",
          author="user",
          content=types.UserContent([response_part]),
      ),
  ]

  async for _ in contents.request_processor.run_async(
      invocation_context, llm_request
  ):
    pass

  # Still exactly three contents: nothing synthetic was appended.
  assert len(llm_request.contents) == 3
  # The third content carries the real tool result, not the error payload.
  final_response = llm_request.contents[2].parts[0].function_response
  assert final_response.response == {"results": ["item1"]}
828+
829+
830+
@pytest.mark.asyncio
async def test_auto_healing_logs_warning(caplog):
  """Healing an orphaned call logs a warning with the call's id and name."""
  model_name = "gemini-2.5-flash"
  agent = Agent(model=model_name, name="test_agent")
  llm_request = LlmRequest(model=model_name)
  invocation_context = await testing_utils.create_invocation_context(
      agent=agent
  )

  call_part = types.Part(
      function_call=types.FunctionCall(
          id="log_test_123", name="test_tool", args={}
      )
  )
  # The history ends with the function_call; no response event follows.
  invocation_context.session.events = [
      Event(
          invocation_id="inv1",
          author="user",
          content=types.UserContent("Test logging"),
      ),
      Event(
          invocation_id="inv2",
          author="test_agent",
          content=types.ModelContent([call_part]),
      ),
  ]

  with caplog.at_level(logging.WARNING):
    async for _ in contents.request_processor.run_async(
        invocation_context, llm_request
    ):
      pass

  # At least one captured record must mention the healing text together
  # with the call's id and name.
  expected_fragments = (
      "Auto-healing orphaned function_call",
      "log_test_123",
      "test_tool",
  )
  assert any(
      all(fragment in record.message for fragment in expected_fragments)
      for record in caplog.records
  )

0 commit comments

Comments
 (0)