Skip to content

Commit fe41817

Browse files
xuanyang15copybara-github
authored andcommitted
fix: Generate IDs for FunctionCalls when processing streaming LLM responses
Close: #4609 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 890144644
1 parent 446575f commit fe41817

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

src/google/adk/utils/streaming_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,13 @@ def _process_function_call_part(self, part: types.Part) -> None:
225225
if fc.partial_args or fc.will_continue:
226226
# Streaming function call arguments
227227

228+
# Generate ID on first chunk if not provided by LLM
229+
if not fc.id and not self._current_fc_id:
230+
# Lazy import to avoid circular dependency
231+
from ..flows.llm_flows.functions import generate_client_function_call_id
232+
233+
fc.id = generate_client_function_call_id()
234+
228235
# Save thought_signature from the part (first chunk should have it)
229236
if part.thought_signature and not self._current_thought_signature:
230237
self._current_thought_signature = part.thought_signature
@@ -233,6 +240,12 @@ def _process_function_call_part(self, part: types.Part) -> None:
233240
# Non-streaming function call (standard format with args)
234241
# Skip empty function calls (used as streaming end markers)
235242
if fc.name:
243+
# Generate ID if not provided by LLM
244+
if not fc.id:
245+
# Lazy import to avoid circular dependency
246+
from ..flows.llm_flows.functions import generate_client_function_call_id
247+
248+
fc.id = generate_client_function_call_id()
236249
# Flush any buffered text first, then add the FC part
237250
self._flush_text_buffer_to_sequence()
238251
self._parts_sequence.append(part)

tests/unittests/utils/test_streaming_utils.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from google.adk.features._feature_registry import FeatureName
1818
from google.adk.features._feature_registry import temporary_feature_override
19+
from google.adk.flows.llm_flows.functions import AF_FUNCTION_CALL_ID_PREFIX
1920
from google.adk.utils import streaming_utils
2021
from google.genai import types
2122
import pytest
@@ -304,3 +305,234 @@ async def run_test():
304305
await run_test()
305306
else:
306307
await run_test()
308+
309+
310+
class TestFunctionCallIdGeneration:
311+
"""Tests for function call ID generation in streaming mode.
312+
313+
Regression tests for https://github.com/google/adk-python/issues/4609.
314+
"""
315+
316+
@pytest.mark.asyncio
317+
async def test_non_streaming_fc_generates_id_when_empty(self):
318+
"""Non-streaming function call should get an adk-* ID if LLM didn't provide one."""
319+
with temporary_feature_override(
320+
FeatureName.PROGRESSIVE_SSE_STREAMING, True
321+
):
322+
aggregator = streaming_utils.StreamingResponseAggregator()
323+
324+
response = types.GenerateContentResponse(
325+
candidates=[
326+
types.Candidate(
327+
content=types.Content(
328+
parts=[
329+
types.Part(
330+
function_call=types.FunctionCall(
331+
name="my_tool",
332+
args={"x": 1},
333+
id=None, # No ID from LLM
334+
)
335+
)
336+
]
337+
),
338+
finish_reason=types.FinishReason.STOP,
339+
)
340+
]
341+
)
342+
343+
async for _ in aggregator.process_response(response):
344+
pass
345+
346+
closed_response = aggregator.close()
347+
assert closed_response is not None
348+
fc = closed_response.content.parts[0].function_call
349+
assert fc.id is not None
350+
assert fc.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
351+
352+
@pytest.mark.asyncio
353+
async def test_non_streaming_fc_preserves_llm_assigned_id(self):
354+
"""Non-streaming function call should preserve ID if LLM provided one."""
355+
with temporary_feature_override(
356+
FeatureName.PROGRESSIVE_SSE_STREAMING, True
357+
):
358+
aggregator = streaming_utils.StreamingResponseAggregator()
359+
360+
response = types.GenerateContentResponse(
361+
candidates=[
362+
types.Candidate(
363+
content=types.Content(
364+
parts=[
365+
types.Part(
366+
function_call=types.FunctionCall(
367+
name="my_tool",
368+
args={"x": 1},
369+
id="llm-assigned-id",
370+
)
371+
)
372+
]
373+
),
374+
finish_reason=types.FinishReason.STOP,
375+
)
376+
]
377+
)
378+
379+
async for _ in aggregator.process_response(response):
380+
pass
381+
382+
closed_response = aggregator.close()
383+
assert closed_response is not None
384+
fc = closed_response.content.parts[0].function_call
385+
assert fc.id == "llm-assigned-id"
386+
387+
@pytest.mark.asyncio
388+
async def test_streaming_fc_generates_consistent_id_across_chunks(self):
389+
"""Streaming function call should have the same ID in partial and final responses."""
390+
with temporary_feature_override(
391+
FeatureName.PROGRESSIVE_SSE_STREAMING, True
392+
):
393+
aggregator = streaming_utils.StreamingResponseAggregator()
394+
395+
# First chunk: function call starts
396+
response1 = types.GenerateContentResponse(
397+
candidates=[
398+
types.Candidate(
399+
content=types.Content(
400+
parts=[
401+
types.Part(
402+
function_call=types.FunctionCall(
403+
name="my_tool",
404+
id=None,
405+
partial_args=[
406+
types.PartialArg(
407+
json_path="$.x",
408+
string_value="hello",
409+
)
410+
],
411+
will_continue=True,
412+
)
413+
)
414+
]
415+
)
416+
)
417+
]
418+
)
419+
420+
# Second chunk: function call continues
421+
response2 = types.GenerateContentResponse(
422+
candidates=[
423+
types.Candidate(
424+
content=types.Content(
425+
parts=[
426+
types.Part(
427+
function_call=types.FunctionCall(
428+
name=None,
429+
id=None,
430+
partial_args=[
431+
types.PartialArg(
432+
json_path="$.x",
433+
string_value=" world",
434+
)
435+
],
436+
will_continue=False, # Complete
437+
)
438+
)
439+
]
440+
),
441+
finish_reason=types.FinishReason.STOP,
442+
)
443+
]
444+
)
445+
446+
partial_results = []
447+
async for r in aggregator.process_response(response1):
448+
partial_results.append(r)
449+
async for r in aggregator.process_response(response2):
450+
partial_results.append(r)
451+
452+
closed_response = aggregator.close()
453+
assert closed_response is not None
454+
final_fc = closed_response.content.parts[0].function_call
455+
assert final_fc.id is not None
456+
assert final_fc.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
457+
assert final_fc.args == {"x": "hello world"}
458+
459+
# Verify partial and final events share the same ID
460+
partial_fc = partial_results[0].content.parts[0].function_call
461+
assert (
462+
partial_fc.id == final_fc.id
463+
), f"Partial FC ID ({partial_fc.id!r}) != Final FC ID ({final_fc.id!r})"
464+
465+
@pytest.mark.asyncio
466+
async def test_multiple_streaming_fcs_get_different_ids(self):
467+
"""Multiple function calls arriving in separate chunks should get different IDs."""
468+
with temporary_feature_override(
469+
FeatureName.PROGRESSIVE_SSE_STREAMING, True
470+
):
471+
aggregator = streaming_utils.StreamingResponseAggregator()
472+
473+
# First FC
474+
response1 = types.GenerateContentResponse(
475+
candidates=[
476+
types.Candidate(
477+
content=types.Content(
478+
parts=[
479+
types.Part(
480+
function_call=types.FunctionCall(
481+
name="tool_a",
482+
id=None,
483+
partial_args=[
484+
types.PartialArg(
485+
json_path="$.a", string_value="val_a"
486+
)
487+
],
488+
will_continue=False,
489+
)
490+
)
491+
]
492+
)
493+
)
494+
]
495+
)
496+
497+
# Second FC
498+
response2 = types.GenerateContentResponse(
499+
candidates=[
500+
types.Candidate(
501+
content=types.Content(
502+
parts=[
503+
types.Part(
504+
function_call=types.FunctionCall(
505+
name="tool_b",
506+
id=None,
507+
partial_args=[
508+
types.PartialArg(
509+
json_path="$.b", string_value="val_b"
510+
)
511+
],
512+
will_continue=False,
513+
)
514+
)
515+
]
516+
),
517+
finish_reason=types.FinishReason.STOP,
518+
)
519+
]
520+
)
521+
522+
async for _ in aggregator.process_response(response1):
523+
pass
524+
async for _ in aggregator.process_response(response2):
525+
pass
526+
527+
closed_response = aggregator.close()
528+
assert closed_response is not None
529+
assert len(closed_response.content.parts) == 2
530+
531+
fc_a = closed_response.content.parts[0].function_call
532+
fc_b = closed_response.content.parts[1].function_call
533+
534+
assert fc_a.id is not None
535+
assert fc_b.id is not None
536+
assert fc_a.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
537+
assert fc_b.id.startswith(AF_FUNCTION_CALL_ID_PREFIX)
538+
assert fc_a.id != fc_b.id # Different IDs for different FCs

0 commit comments

Comments
 (0)