Skip to content

Commit 71b936b

Browse files
wikaaaaacopybara-github
authored andcommitted
fix: prevent compaction from orphaning function responses
Co-authored-by: Wiktoria Walczak <wwalczak@google.com> PiperOrigin-RevId: 932325806
1 parent 22adbe1 commit 71b936b

3 files changed

Lines changed: 362 additions & 98 deletions

File tree

src/google/adk/apps/compaction.py

Lines changed: 20 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,7 @@ def _events_to_compact_for_token_threshold(
266266
event_retention_size=event_retention_size,
267267
)
268268
events_to_compact = candidate_events[:split_index]
269-
pending_ids = _pending_function_call_ids(events)
270-
events_to_compact = _truncate_events_before_pending_function_call(
271-
events_to_compact, pending_ids
272-
)
273-
events_to_compact = _truncate_events_before_hitl_signal(
274-
events_to_compact, _resolved_hitl_call_ids(events)
275-
)
269+
events_to_compact = _longest_self_contained_prefix(events_to_compact)
276270
if not events_to_compact:
277271
return []
278272

@@ -313,76 +307,28 @@ def _event_function_response_ids(event: Event) -> set[str]:
313307
return function_response_ids
314308

315309

316-
def _pending_function_call_ids(events: list[Event]) -> set[str]:
317-
"""Returns function call IDs that have no matching response in the session.
310+
def _longest_self_contained_prefix(events: list[Event]) -> list[Event]:
311+
"""Returns the longest prefix of `events` that is safe to compact.
318312
319-
Scans the session once, collecting function call IDs and response IDs, then
320-
returns the call IDs that are not covered by any response. Events containing
321-
these IDs represent pending (unanswered) function calls that must not be
322-
compacted.
313+
Performs a single left-to-right pass tracking "open" obligations keyed by call
314+
id: a function call or a tool-confirmation / auth request opens one, and a
315+
function response with the same id closes it. Responses are applied before
316+
opens within each event so a response only closes an obligation opened by an
317+
earlier event. The prefix is safe to summarize only at points where no
318+
obligation is open, so the longest prefix ending at such a balanced point is
319+
returned (empty if the window never reaches a balanced point).
323320
"""
324-
all_call_ids: set[str] = set()
325-
all_response_ids: set[str] = set()
326-
for event in events:
327-
all_call_ids.update(_event_function_call_ids(event))
328-
all_response_ids.update(_event_function_response_ids(event))
329-
330-
return all_call_ids - all_response_ids
331-
332-
333-
def _has_pending_function_call(event: Event, pending_ids: set[str]) -> bool:
334-
"""Returns True if the event contains any pending function call."""
335-
call_ids = _event_function_call_ids(event)
336-
return bool(call_ids and not call_ids.isdisjoint(pending_ids))
337-
338-
339-
def _truncate_events_before_pending_function_call(
340-
events: list[Event], pending_ids: set[str]
341-
) -> list[Event]:
342-
"""Returns the leading contiguous events that avoid pending function calls."""
343-
for index, event in enumerate(events):
344-
if _has_pending_function_call(event, pending_ids):
345-
return events[:index]
346-
return events
347-
348-
349-
def _resolved_hitl_call_ids(events: list[Event]) -> set[str]:
350-
"""Returns HITL call ids resolved by a later function_response in `events`."""
351-
hitl_position: dict[str, int] = {}
352-
resolved: set[str] = set()
321+
open_ids: set[str] = set()
322+
safe_length = 0
353323
for index, event in enumerate(events):
324+
open_ids -= _event_function_response_ids(event)
325+
open_ids |= _event_function_call_ids(event)
354326
if event.actions:
355-
for call_id in event.actions.requested_tool_confirmations:
356-
hitl_position.setdefault(call_id, index)
357-
for call_id in event.actions.requested_auth_configs:
358-
hitl_position.setdefault(call_id, index)
359-
for resp_id in _event_function_response_ids(event):
360-
hitl_pos = hitl_position.get(resp_id)
361-
if hitl_pos is not None and index > hitl_pos:
362-
resolved.add(resp_id)
363-
return resolved
364-
365-
366-
def _is_pending_hitl(event: Event, resolved_call_ids: set[str]) -> bool:
367-
"""Returns True if the event has an HITL request not in `resolved_call_ids`."""
368-
if not event.actions:
369-
return False
370-
requested = set(event.actions.requested_tool_confirmations) | set(
371-
event.actions.requested_auth_configs
372-
)
373-
if not requested:
374-
return False
375-
return bool(requested - resolved_call_ids)
376-
377-
378-
def _truncate_events_before_hitl_signal(
379-
events: list[Event], resolved_call_ids: set[str]
380-
) -> list[Event]:
381-
"""Returns the leading contiguous events before any pending HITL request."""
382-
for index, event in enumerate(events):
383-
if _is_pending_hitl(event, resolved_call_ids):
384-
return events[:index]
385-
return events
327+
open_ids |= set(event.actions.requested_tool_confirmations)
328+
open_ids |= set(event.actions.requested_auth_configs)
329+
if not open_ids:
330+
safe_length = index + 1
331+
return events[:safe_length]
386332

387333

388334
def _safe_token_compaction_split_index(
@@ -664,13 +610,7 @@ async def _run_compaction_for_sliding_window(
664610
events_to_compact = [
665611
e for e in events_to_compact if not e.actions.compaction
666612
]
667-
pending_ids = _pending_function_call_ids(events)
668-
events_to_compact = _truncate_events_before_pending_function_call(
669-
events_to_compact, pending_ids
670-
)
671-
events_to_compact = _truncate_events_before_hitl_signal(
672-
events_to_compact, _resolved_hitl_call_ids(events)
673-
)
613+
events_to_compact = _longest_self_contained_prefix(events_to_compact)
674614

675615
if not events_to_compact:
676616
return None

0 commit comments

Comments
 (0)