Skip to content

Commit e99b028

Browse files
committed
feat: add thinking config for anthropic llm
1 parent 60b9073 commit e99b028

2 files changed

Lines changed: 720 additions & 2 deletions

File tree

src/google/adk/models/anthropic_llm.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ class _ToolUseAccumulator:
6262
args_json: str
6363

6464

65+
class _ThinkingAccumulator(BaseModel):
  """Collects streamed extended-thinking content for one block index.

  While handling `content_block_delta` events, the streaming loop appends
  each ThinkingDelta fragment and records the final SignatureDelta here so
  the complete thinking block can be rebuilt when the stream finishes.
  """

  # Concatenation of all ThinkingDelta fragments received so far.
  thinking: str = ""
  # Signature delivered via a SignatureDelta event (empty until received).
  signature: str = ""
70+
71+
6572
class ClaudeRequest(BaseModel):
6673
system_instruction: str
6774
messages: Iterable[anthropic_types.MessageParam]
@@ -108,7 +115,24 @@ def part_to_message_block(
108115
anthropic_types.DocumentBlockParam,
109116
anthropic_types.ToolUseBlockParam,
110117
anthropic_types.ToolResultBlockParam,
118+
anthropic_types.ThinkingBlockParam,
119+
anthropic_types.RedactedThinkingBlockParam,
111120
]:
121+
if part.thought:
122+
signature_str = (
123+
part.thought_signature.decode("utf-8") if part.thought_signature else ""
124+
)
125+
if part.text:
126+
return anthropic_types.ThinkingBlockParam(
127+
type="thinking",
128+
thinking=part.text,
129+
signature=signature_str,
130+
)
131+
else:
132+
return anthropic_types.RedactedThinkingBlockParam(
133+
type="redacted_thinking",
134+
data=signature_str,
135+
)
112136
if part.text:
113137
return anthropic_types.TextBlockParam(text=part.text, type="text")
114138
elif part.function_call:
@@ -229,6 +253,18 @@ def content_block_to_part(
229253
)
230254
part.function_call.id = content_block.id
231255
return part
256+
if isinstance(content_block, anthropic_types.ThinkingBlock):
257+
return types.Part(
258+
text=content_block.thinking,
259+
thought=True,
260+
thought_signature=content_block.signature.encode("utf-8"),
261+
)
262+
if isinstance(content_block, anthropic_types.RedactedThinkingBlock):
263+
return types.Part(
264+
text="",
265+
thought=True,
266+
thought_signature=content_block.data.encode("utf-8"),
267+
)
232268
raise NotImplementedError("Not supported yet.")
233269

234270

@@ -349,6 +385,26 @@ def function_declaration_to_tool_param(
349385
)
350386

351387

388+
def _build_thinking_param(
    thinking_config: Optional[types.ThinkingConfig],
    max_tokens: int,
) -> Union[anthropic_types.ThinkingConfigEnabledParam, NotGiven]:
  """Maps an ADK ThinkingConfig onto Anthropic's `thinking` parameter.

  Args:
    thinking_config: The ADK thinking configuration, or None when the
      request carries no generation config.
    max_tokens: The request's max_tokens; the Anthropic API requires
      budget_tokens to be strictly smaller than it.

  Returns:
    A ThinkingConfigEnabledParam with the budget clamped to
    ``max_tokens - 1``, or NOT_GIVEN when thinking is unconfigured or the
    budget is falsy (None or 0).
  """
  if thinking_config is None:
    return NOT_GIVEN
  budget_tokens = thinking_config.thinking_budget
  if not budget_tokens:
    return NOT_GIVEN
  # NOTE(review): a negative budget (e.g. Gemini's -1 "dynamic" convention)
  # is truthy and passes through min() unclamped — confirm callers never
  # send one, since the Anthropic API rejects negative budget_tokens.
  return anthropic_types.ThinkingConfigEnabledParam(
      type="enabled",
      budget_tokens=min(budget_tokens, max_tokens - 1),
  )
406+
407+
352408
class AnthropicLlm(BaseLlm):
353409
"""Integration with Claude models via the Anthropic API.
354410
@@ -401,6 +457,10 @@ async def generate_content_async(
401457
if llm_request.tools_dict
402458
else NOT_GIVEN
403459
)
460+
thinking = _build_thinking_param(
461+
llm_request.config.thinking_config if llm_request.config else None,
462+
self.max_tokens,
463+
)
404464

405465
if not stream:
406466
message = await self._anthropic_client.messages.create(
@@ -410,11 +470,12 @@ async def generate_content_async(
410470
tools=tools,
411471
tool_choice=tool_choice,
412472
max_tokens=self.max_tokens,
473+
thinking=thinking,
413474
)
414475
yield message_to_generate_content_response(message)
415476
else:
416477
async for response in self._generate_content_streaming(
417-
llm_request, messages, tools, tool_choice
478+
llm_request, messages, tools, tool_choice, thinking
418479
):
419480
yield response
420481

@@ -424,6 +485,9 @@ async def _generate_content_streaming(
424485
messages: list[anthropic_types.MessageParam],
425486
tools: Union[Iterable[anthropic_types.ToolUnionParam], NotGiven],
426487
tool_choice: Union[anthropic_types.ToolChoiceParam, NotGiven],
488+
thinking: Union[
489+
anthropic_types.ThinkingConfigEnabledParam, NotGiven
490+
] = NOT_GIVEN,
427491
) -> AsyncGenerator[LlmResponse, None]:
428492
"""Handles streaming responses from Anthropic models.
429493
@@ -439,12 +503,15 @@ async def _generate_content_streaming(
439503
tool_choice=tool_choice,
440504
max_tokens=self.max_tokens,
441505
stream=True,
506+
thinking=thinking,
442507
)
443508

444509
# Track content blocks being built during streaming.
445510
# Each entry maps a block index to its accumulated state.
446511
text_blocks: dict[int, str] = {}
447512
tool_use_blocks: dict[int, _ToolUseAccumulator] = {}
513+
thinking_blocks: dict[int, _ThinkingAccumulator] = {}
514+
redacted_thinking_blocks: dict[int, str] = {}
448515
input_tokens = 0
449516
output_tokens = 0
450517

@@ -463,6 +530,10 @@ async def _generate_content_streaming(
463530
name=block.name,
464531
args_json="",
465532
)
533+
elif isinstance(block, anthropic_types.ThinkingBlock):
534+
thinking_blocks[event.index] = _ThinkingAccumulator()
535+
elif isinstance(block, anthropic_types.RedactedThinkingBlock):
536+
redacted_thinking_blocks[event.index] = block.data
466537

467538
elif event.type == "content_block_delta":
468539
delta = event.delta
@@ -479,16 +550,44 @@ async def _generate_content_streaming(
479550
elif isinstance(delta, anthropic_types.InputJSONDelta):
480551
if event.index in tool_use_blocks:
481552
tool_use_blocks[event.index].args_json += delta.partial_json
553+
elif isinstance(delta, anthropic_types.ThinkingDelta):
554+
if event.index in thinking_blocks:
555+
thinking_blocks[event.index].thinking += delta.thinking
556+
elif isinstance(delta, anthropic_types.SignatureDelta):
557+
if event.index in thinking_blocks:
558+
thinking_blocks[event.index].signature = delta.signature
482559

483560
elif event.type == "message_delta":
484561
output_tokens = event.usage.output_tokens
485562

486563
# Build the final aggregated response with all content.
487564
all_parts: list[types.Part] = []
488565
all_indices = sorted(
489-
set(list(text_blocks.keys()) + list(tool_use_blocks.keys()))
566+
set(
567+
list(text_blocks.keys())
568+
+ list(tool_use_blocks.keys())
569+
+ list(thinking_blocks.keys())
570+
+ list(redacted_thinking_blocks.keys())
571+
)
490572
)
491573
for idx in all_indices:
574+
if idx in thinking_blocks:
575+
acc = thinking_blocks[idx]
576+
all_parts.append(
577+
types.Part(
578+
text=acc.thinking,
579+
thought=True,
580+
thought_signature=acc.signature.encode("utf-8"),
581+
)
582+
)
583+
if idx in redacted_thinking_blocks:
584+
all_parts.append(
585+
types.Part(
586+
text="",
587+
thought=True,
588+
thought_signature=redacted_thinking_blocks[idx].encode("utf-8"),
589+
)
590+
)
492591
if idx in text_blocks:
493592
all_parts.append(types.Part.from_text(text=text_blocks[idx]))
494593
if idx in tool_use_blocks:

0 commit comments

Comments
 (0)