Skip to content

Commit 5770cd3

Browse files
GWeale and copybara-github
authored and committed
feat: Add streaming support for Anthropic models
Refactor ToolResultBlockParam content handling to use json.dumps for dict/list results. Implement _generate_content_streaming to handle Anthropic's streaming API. Closes #3250. Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 877613612
1 parent 80c5a24 commit 5770cd3

File tree

2 files changed

+474
-9
lines changed

2 files changed

+474
-9
lines changed

src/google/adk/models/anthropic_llm.py

Lines changed: 119 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from __future__ import annotations
1818

1919
import base64
20+
import dataclasses
2021
from functools import cached_property
22+
import json
2123
import logging
2224
import os
2325
from typing import Any
@@ -31,6 +33,7 @@
3133
from anthropic import AsyncAnthropic
3234
from anthropic import AsyncAnthropicVertex
3335
from anthropic import NOT_GIVEN
36+
from anthropic import NotGiven
3437
from anthropic import types as anthropic_types
3538
from google.genai import types
3639
from pydantic import BaseModel
@@ -48,6 +51,15 @@
4851
logger = logging.getLogger("google_adk." + __name__)
4952

5053

54+
@dataclasses.dataclass
55+
class _ToolUseAccumulator:
56+
"""Accumulates streamed tool_use content block data."""
57+
58+
id: str
59+
name: str
60+
args_json: str
61+
62+
5163
class ClaudeRequest(BaseModel):
5264
system_instruction: str
5365
messages: Iterable[anthropic_types.MessageParam]
@@ -115,12 +127,15 @@ def part_to_message_block(
115127
else:
116128
content_items.append(str(item))
117129
content = "\n".join(content_items) if content_items else ""
118-
# Handle traditional result format
119-
elif "result" in response_data and response_data["result"]:
120-
# Transformation is required because the content is a list of dict.
121-
# ToolResultBlockParam content doesn't support list of dict. Converting
122-
# to str to prevent anthropic.BadRequestError from being thrown.
123-
content = str(response_data["result"])
130+
# We serialize to str here
131+
# SDK ref: anthropic.types.tool_result_block_param
132+
# https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_result_block_param.py
133+
elif "result" in response_data and response_data["result"] is not None:
134+
result = response_data["result"]
135+
if isinstance(result, (dict, list)):
136+
content = json.dumps(result)
137+
else:
138+
content = str(result)
124139

125140
return anthropic_types.ToolResultBlockParam(
126141
tool_use_id=part.function_response.id or "",
@@ -305,16 +320,111 @@ async def generate_content_async(
305320
if llm_request.tools_dict
306321
else NOT_GIVEN
307322
)
308-
# TODO(b/421255973): Enable streaming for anthropic models.
309-
message = await self._anthropic_client.messages.create(
323+
324+
if not stream:
325+
message = await self._anthropic_client.messages.create(
326+
model=llm_request.model,
327+
system=llm_request.config.system_instruction,
328+
messages=messages,
329+
tools=tools,
330+
tool_choice=tool_choice,
331+
max_tokens=self.max_tokens,
332+
)
333+
yield message_to_generate_content_response(message)
334+
else:
335+
async for response in self._generate_content_streaming(
336+
llm_request, messages, tools, tool_choice
337+
):
338+
yield response
339+
340+
async def _generate_content_streaming(
    self,
    llm_request: LlmRequest,
    messages: list[anthropic_types.MessageParam],
    tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven],
    tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven],
) -> AsyncGenerator[LlmResponse, None]:
  """Handles streaming responses from Anthropic models.

  Yields a partial LlmResponse for every text delta as it arrives, then a
  final aggregated LlmResponse containing all text and tool-use content
  plus usage metadata.
  """
  stream = await self._anthropic_client.messages.create(
      model=llm_request.model,
      system=llm_request.config.system_instruction,
      messages=messages,
      tools=tools,
      tool_choice=tool_choice,
      max_tokens=self.max_tokens,
      stream=True,
  )

  # Per-index accumulated state for the content blocks seen so far.
  streamed_text: dict[int, str] = {}
  streamed_tool_use: dict[int, _ToolUseAccumulator] = {}
  prompt_tokens = 0
  completion_tokens = 0

  async for event in stream:
    kind = event.type
    if kind == "message_start":
      # Initial usage snapshot; output count is refined by message_delta.
      prompt_tokens = event.message.usage.input_tokens
      completion_tokens = event.message.usage.output_tokens
    elif kind == "content_block_start":
      opened = event.content_block
      if isinstance(opened, anthropic_types.TextBlock):
        streamed_text[event.index] = opened.text
      elif isinstance(opened, anthropic_types.ToolUseBlock):
        streamed_tool_use[event.index] = _ToolUseAccumulator(
            id=opened.id,
            name=opened.name,
            args_json="",
        )
    elif kind == "content_block_delta":
      piece = event.delta
      if isinstance(piece, anthropic_types.TextDelta):
        streamed_text[event.index] = (
            streamed_text.get(event.index, "") + piece.text
        )
        # Surface each text fragment immediately as a partial response.
        yield LlmResponse(
            content=types.Content(
                role="model",
                parts=[types.Part.from_text(text=piece.text)],
            ),
            partial=True,
        )
      elif isinstance(piece, anthropic_types.InputJSONDelta):
        if event.index in streamed_tool_use:
          streamed_tool_use[event.index].args_json += piece.partial_json
    elif kind == "message_delta":
      completion_tokens = event.usage.output_tokens

  # Assemble the final response, preserving the original block ordering.
  final_parts: list[types.Part] = []
  for index in sorted(streamed_text.keys() | streamed_tool_use.keys()):
    if index in streamed_text:
      final_parts.append(types.Part.from_text(text=streamed_text[index]))
    if index in streamed_tool_use:
      accumulator = streamed_tool_use[index]
      call_args = (
          json.loads(accumulator.args_json) if accumulator.args_json else {}
      )
      call_part = types.Part.from_function_call(
          name=accumulator.name, args=call_args
      )
      call_part.function_call.id = accumulator.id
      final_parts.append(call_part)

  yield LlmResponse(
      content=types.Content(role="model", parts=final_parts),
      usage_metadata=types.GenerateContentResponseUsageMetadata(
          prompt_token_count=prompt_tokens,
          candidates_token_count=completion_tokens,
          total_token_count=prompt_tokens + completion_tokens,
      ),
      partial=False,
  )
317-
yield message_to_generate_content_response(message)
318428

319429
@cached_property
320430
def _anthropic_client(self) -> AsyncAnthropic:

0 commit comments

Comments
 (0)