|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | 19 | import base64 |
| 20 | +import dataclasses |
20 | 21 | from functools import cached_property |
| 22 | +import json |
21 | 23 | import logging |
22 | 24 | import os |
23 | 25 | from typing import Any |
|
31 | 33 | from anthropic import AsyncAnthropic |
32 | 34 | from anthropic import AsyncAnthropicVertex |
33 | 35 | from anthropic import NOT_GIVEN |
| 36 | +from anthropic import NotGiven |
34 | 37 | from anthropic import types as anthropic_types |
35 | 38 | from google.genai import types |
36 | 39 | from pydantic import BaseModel |
|
48 | 51 | logger = logging.getLogger("google_adk." + __name__) |
49 | 52 |
|
50 | 53 |
|
@dataclasses.dataclass
class _ToolUseAccumulator:
  """Accumulates streamed tool_use content block data.

  A tool_use block arrives over the stream as a content_block_start event
  (carrying the block id and tool name) followed by input_json_delta events
  that deliver the JSON-encoded arguments in fragments. This holds the
  pieces until the stream completes and the arguments can be parsed.
  """

  # Tool-use block id assigned by the Anthropic API.
  id: str
  # Name of the tool being invoked.
  name: str
  # Concatenation of the partial_json fragments received so far; parsed
  # with json.loads once the stream is finished (empty string -> {}).
  args_json: str
| 62 | + |
51 | 63 | class ClaudeRequest(BaseModel): |
52 | 64 | system_instruction: str |
53 | 65 | messages: Iterable[anthropic_types.MessageParam] |
@@ -115,12 +127,15 @@ def part_to_message_block( |
115 | 127 | else: |
116 | 128 | content_items.append(str(item)) |
117 | 129 | content = "\n".join(content_items) if content_items else "" |
118 | | - # Handle traditional result format |
119 | | - elif "result" in response_data and response_data["result"]: |
120 | | - # Transformation is required because the content is a list of dict. |
121 | | - # ToolResultBlockParam content doesn't support list of dict. Converting |
122 | | - # to str to prevent anthropic.BadRequestError from being thrown. |
123 | | - content = str(response_data["result"]) |
| 130 | + # We serialize to str here |
| 131 | + # SDK ref: anthropic.types.tool_result_block_param |
| 132 | + # https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/tool_result_block_param.py |
| 133 | + elif "result" in response_data and response_data["result"] is not None: |
| 134 | + result = response_data["result"] |
| 135 | + if isinstance(result, (dict, list)): |
| 136 | + content = json.dumps(result) |
| 137 | + else: |
| 138 | + content = str(result) |
124 | 139 |
|
125 | 140 | return anthropic_types.ToolResultBlockParam( |
126 | 141 | tool_use_id=part.function_response.id or "", |
@@ -305,16 +320,111 @@ async def generate_content_async( |
305 | 320 | if llm_request.tools_dict |
306 | 321 | else NOT_GIVEN |
307 | 322 | ) |
308 | | - # TODO(b/421255973): Enable streaming for anthropic models. |
309 | | - message = await self._anthropic_client.messages.create( |
| 323 | + |
| 324 | + if not stream: |
| 325 | + message = await self._anthropic_client.messages.create( |
| 326 | + model=llm_request.model, |
| 327 | + system=llm_request.config.system_instruction, |
| 328 | + messages=messages, |
| 329 | + tools=tools, |
| 330 | + tool_choice=tool_choice, |
| 331 | + max_tokens=self.max_tokens, |
| 332 | + ) |
| 333 | + yield message_to_generate_content_response(message) |
| 334 | + else: |
| 335 | + async for response in self._generate_content_streaming( |
| 336 | + llm_request, messages, tools, tool_choice |
| 337 | + ): |
| 338 | + yield response |
| 339 | + |
| 340 | + async def _generate_content_streaming( |
| 341 | + self, |
| 342 | + llm_request: LlmRequest, |
| 343 | + messages: list[anthropic_types.MessageParam], |
| 344 | + tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven], |
| 345 | + tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven], |
| 346 | + ) -> AsyncGenerator[LlmResponse, None]: |
| 347 | + """Handles streaming responses from Anthropic models. |
| 348 | +
|
| 349 | + Yields partial LlmResponse objects as content arrives, followed by |
| 350 | + a final aggregated LlmResponse with all content. |
| 351 | + """ |
| 352 | + raw_stream = await self._anthropic_client.messages.create( |
310 | 353 | model=llm_request.model, |
311 | 354 | system=llm_request.config.system_instruction, |
312 | 355 | messages=messages, |
313 | 356 | tools=tools, |
314 | 357 | tool_choice=tool_choice, |
315 | 358 | max_tokens=self.max_tokens, |
| 359 | + stream=True, |
| 360 | + ) |
| 361 | + |
| 362 | + # Track content blocks being built during streaming. |
| 363 | + # Each entry maps a block index to its accumulated state. |
| 364 | + text_blocks: dict[int, str] = {} |
| 365 | + tool_use_blocks: dict[int, _ToolUseAccumulator] = {} |
| 366 | + input_tokens = 0 |
| 367 | + output_tokens = 0 |
| 368 | + |
| 369 | + async for event in raw_stream: |
| 370 | + if event.type == "message_start": |
| 371 | + input_tokens = event.message.usage.input_tokens |
| 372 | + output_tokens = event.message.usage.output_tokens |
| 373 | + |
| 374 | + elif event.type == "content_block_start": |
| 375 | + block = event.content_block |
| 376 | + if isinstance(block, anthropic_types.TextBlock): |
| 377 | + text_blocks[event.index] = block.text |
| 378 | + elif isinstance(block, anthropic_types.ToolUseBlock): |
| 379 | + tool_use_blocks[event.index] = _ToolUseAccumulator( |
| 380 | + id=block.id, |
| 381 | + name=block.name, |
| 382 | + args_json="", |
| 383 | + ) |
| 384 | + |
| 385 | + elif event.type == "content_block_delta": |
| 386 | + delta = event.delta |
| 387 | + if isinstance(delta, anthropic_types.TextDelta): |
| 388 | + text_blocks.setdefault(event.index, "") |
| 389 | + text_blocks[event.index] += delta.text |
| 390 | + yield LlmResponse( |
| 391 | + content=types.Content( |
| 392 | + role="model", |
| 393 | + parts=[types.Part.from_text(text=delta.text)], |
| 394 | + ), |
| 395 | + partial=True, |
| 396 | + ) |
| 397 | + elif isinstance(delta, anthropic_types.InputJSONDelta): |
| 398 | + if event.index in tool_use_blocks: |
| 399 | + tool_use_blocks[event.index].args_json += delta.partial_json |
| 400 | + |
| 401 | + elif event.type == "message_delta": |
| 402 | + output_tokens = event.usage.output_tokens |
| 403 | + |
| 404 | + # Build the final aggregated response with all content. |
| 405 | + all_parts: list[types.Part] = [] |
| 406 | + all_indices = sorted( |
| 407 | + set(list(text_blocks.keys()) + list(tool_use_blocks.keys())) |
| 408 | + ) |
| 409 | + for idx in all_indices: |
| 410 | + if idx in text_blocks: |
| 411 | + all_parts.append(types.Part.from_text(text=text_blocks[idx])) |
| 412 | + if idx in tool_use_blocks: |
| 413 | + acc = tool_use_blocks[idx] |
| 414 | + args = json.loads(acc.args_json) if acc.args_json else {} |
| 415 | + part = types.Part.from_function_call(name=acc.name, args=args) |
| 416 | + part.function_call.id = acc.id |
| 417 | + all_parts.append(part) |
| 418 | + |
| 419 | + yield LlmResponse( |
| 420 | + content=types.Content(role="model", parts=all_parts), |
| 421 | + usage_metadata=types.GenerateContentResponseUsageMetadata( |
| 422 | + prompt_token_count=input_tokens, |
| 423 | + candidates_token_count=output_tokens, |
| 424 | + total_token_count=input_tokens + output_tokens, |
| 425 | + ), |
| 426 | + partial=False, |
316 | 427 | ) |
317 | | - yield message_to_generate_content_response(message) |
318 | 428 |
|
319 | 429 | @cached_property |
320 | 430 | def _anthropic_client(self) -> AsyncAnthropic: |
|
0 commit comments