Skip to content

Commit 8befdb8

Browse files
wyf7107 and copybara-github
authored and committed
fix(streaming): Ensure final partial=False frame is always yielded
The StreamingResponseAggregator.close() method previously returned None if it had not accumulated any text or parts, as happens for safety blocks or pure function calls. This caused clients (e.g., Vertex AI Reasoning Engine) to hang indefinitely waiting for a partial=False termination frame, and caused loops to break prematurely.

This fix ensures close() always returns a final LlmResponse(partial=False) as long as a response exists, carrying over any error_code, error_message, and usage_metadata, regardless of whether PROGRESSIVE_SSE_STREAMING is enabled. Added parameterized unit tests to verify behavior across both streaming modes.

Fixes #3754

Co-authored-by: Yifan Wang <wanyif@google.com>
PiperOrigin-RevId: 927411984
1 parent 2ad2005 commit 8befdb8

2 files changed

Lines changed: 148 additions & 67 deletions

File tree

src/google/adk/utils/streaming_utils.py

Lines changed: 49 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -349,61 +349,61 @@ def close(self) -> Optional[LlmResponse]:
349349
Returns:
350350
The aggregated LlmResponse.
351351
"""
352+
if not self._response:
353+
return None
354+
355+
candidate = (
356+
self._response.candidates[0] if self._response.candidates else None
357+
)
358+
359+
finish_reason = self._finish_reason
360+
if not finish_reason and candidate:
361+
finish_reason = candidate.finish_reason
362+
363+
error_code = None
364+
error_message = None
365+
if finish_reason and finish_reason != types.FinishReason.STOP:
366+
error_code = finish_reason
367+
error_message = candidate.finish_message if candidate else None
368+
elif not candidate and self._response.prompt_feedback:
369+
error_code = self._response.prompt_feedback.block_reason
370+
error_message = self._response.prompt_feedback.block_reason_message
371+
352372
# ========== Progressive SSE Streaming (new feature) ==========
353373
if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING):
354-
# Always generate final aggregated response in progressive mode
355-
if self._response and self._response.candidates:
356-
# Flush any remaining buffers to complete the sequence
357-
self._flush_text_buffer_to_sequence()
358-
self._flush_function_call_to_sequence()
359-
360-
# Use the parts sequence which preserves original ordering
361-
final_parts = self._parts_sequence
362-
363-
if final_parts:
364-
candidate = self._response.candidates[0]
365-
finish_reason = self._finish_reason or candidate.finish_reason
366-
367-
return LlmResponse(
368-
content=types.ModelContent(parts=final_parts),
369-
grounding_metadata=self._grounding_metadata,
370-
citation_metadata=self._citation_metadata,
371-
error_code=None
372-
if finish_reason == types.FinishReason.STOP
373-
else finish_reason,
374-
error_message=None
375-
if finish_reason == types.FinishReason.STOP
376-
else candidate.finish_message,
377-
usage_metadata=self._usage_metadata,
378-
finish_reason=finish_reason,
379-
partial=False,
380-
)
381-
382-
return None
374+
self._flush_text_buffer_to_sequence()
375+
self._flush_function_call_to_sequence()
376+
377+
final_parts = self._parts_sequence
378+
content = types.ModelContent(parts=final_parts) if final_parts else None
383379

384-
# ========== Non-Progressive SSE Streaming (old behavior) ==========
385-
if (
386-
(self._text or self._thought_text)
387-
and self._response
388-
and self._response.candidates
389-
):
390-
parts = []
391-
if self._thought_text:
392-
parts.append(types.Part(text=self._thought_text, thought=True))
393-
if self._text:
394-
parts.append(types.Part.from_text(text=self._text))
395-
candidate = self._response.candidates[0]
396380
return LlmResponse(
397-
content=types.ModelContent(parts=parts),
381+
content=content,
398382
grounding_metadata=self._grounding_metadata,
399383
citation_metadata=self._citation_metadata,
400-
error_code=None
401-
if candidate.finish_reason == types.FinishReason.STOP
402-
else candidate.finish_reason,
403-
error_message=None
404-
if candidate.finish_reason == types.FinishReason.STOP
405-
else candidate.finish_message,
384+
error_code=error_code,
385+
error_message=error_message,
406386
usage_metadata=self._usage_metadata,
387+
finish_reason=finish_reason,
388+
partial=False,
407389
)
408390

409-
return None
391+
# ========== Non-Progressive SSE Streaming (old behavior) ==========
392+
parts = []
393+
if self._thought_text:
394+
parts.append(types.Part(text=self._thought_text, thought=True))
395+
if self._text:
396+
parts.append(types.Part.from_text(text=self._text))
397+
content = types.ModelContent(parts=parts) if parts else None
398+
399+
return LlmResponse(
400+
content=content,
401+
grounding_metadata=self._grounding_metadata,
402+
citation_metadata=self._citation_metadata,
403+
error_code=error_code,
404+
error_message=error_message,
405+
usage_metadata=self._usage_metadata,
406+
finish_reason=finish_reason,
407+
partial=False,
408+
)
409+

tests/unittests/utils/test_streaming_utils.py

Lines changed: 99 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -184,25 +184,106 @@ async def test_close_with_error(self):
184184
assert closed_response.error_message == "Recitation error"
185185

186186
@pytest.mark.asyncio
187-
async def test_process_response_with_none_content(self):
188-
"""Test that StreamingResponseAggregator handles content=None."""
189-
aggregator = streaming_utils.StreamingResponseAggregator()
190-
response = types.GenerateContentResponse(
191-
candidates=[
192-
types.Candidate(
193-
content=types.Content(parts=[]),
194-
finish_reason=types.FinishReason.STOP,
195-
)
196-
]
197-
)
198-
results = []
199-
async for r in aggregator.process_response(response):
200-
results.append(r)
201-
assert len(results) == 1
202-
assert results[0].content is not None
187+
@pytest.mark.parametrize("use_progressive_sse", [True, False])
188+
async def test_empty_content_produces_empty_final_frame(
189+
self, use_progressive_sse
190+
):
191+
"""A candidate with an empty parts list produces an empty final frame."""
192+
with temporary_feature_override(
193+
FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse
194+
):
195+
aggregator = streaming_utils.StreamingResponseAggregator()
196+
response = types.GenerateContentResponse(
197+
candidates=[
198+
types.Candidate(
199+
content=types.Content(parts=[]),
200+
finish_reason=types.FinishReason.STOP,
201+
)
202+
]
203+
)
204+
results = []
205+
async for r in aggregator.process_response(response):
206+
results.append(r)
207+
closed_response = aggregator.close()
208+
209+
assert len(results) == 1
210+
assert results[0].content is not None
211+
assert closed_response is not None
212+
assert closed_response.partial is False
213+
assert closed_response.content is None
214+
assert closed_response.finish_reason == types.FinishReason.STOP
215+
216+
@pytest.mark.asyncio
217+
@pytest.mark.parametrize("use_progressive_sse", [True, False])
218+
async def test_prompt_feedback_block_returns_error_frame(
219+
self, use_progressive_sse
220+
):
221+
"""A prompt-level safety block produces a final frame with the error code."""
222+
with temporary_feature_override(
223+
FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse
224+
):
225+
aggregator = streaming_utils.StreamingResponseAggregator()
226+
response = types.GenerateContentResponse(
227+
prompt_feedback=types.GenerateContentResponsePromptFeedback(
228+
block_reason=types.BlockedReason.SAFETY,
229+
block_reason_message="Blocked by safety",
230+
)
231+
)
232+
results = []
233+
async for r in aggregator.process_response(response):
234+
results.append(r)
235+
closed_response = aggregator.close()
236+
237+
assert len(results) == 1
238+
assert closed_response is not None
239+
assert closed_response.partial is False
240+
assert closed_response.error_code == types.BlockedReason.SAFETY
241+
assert closed_response.error_message == "Blocked by safety"
242+
assert closed_response.content is None
243+
244+
@pytest.mark.asyncio
245+
@pytest.mark.parametrize("use_progressive_sse", [True, False])
246+
async def test_pure_function_call_behavior_differs_by_mode(
247+
self, use_progressive_sse
248+
):
249+
"""A pure function call yields the part in progressive mode and an empty frame otherwise."""
250+
with temporary_feature_override(
251+
FeatureName.PROGRESSIVE_SSE_STREAMING, use_progressive_sse
252+
):
253+
aggregator = streaming_utils.StreamingResponseAggregator()
254+
response = types.GenerateContentResponse(
255+
candidates=[
256+
types.Candidate(
257+
content=types.Content(
258+
parts=[
259+
types.Part(
260+
function_call=types.FunctionCall(
261+
name="my_tool",
262+
args={"x": 1},
263+
)
264+
)
265+
]
266+
),
267+
finish_reason=types.FinishReason.STOP,
268+
)
269+
]
270+
)
271+
272+
results = []
273+
async for r in aggregator.process_response(response):
274+
results.append(r)
275+
closed_response = aggregator.close()
276+
277+
assert closed_response is not None
278+
assert closed_response.partial is False
279+
280+
if use_progressive_sse:
281+
assert closed_response.content is not None
282+
assert len(closed_response.content.parts) == 1
283+
assert closed_response.content.parts[0].function_call.name == "my_tool"
284+
else:
285+
assert closed_response.content is None
203286

204-
closed_response = aggregator.close()
205-
assert closed_response is None
206287

207288
@pytest.mark.asyncio
208289
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)