
Commit e600e60

fix(models): propagate Anthropic finish reasons
1 parent 684a6e7 commit e600e60

2 files changed: 77 additions & 5 deletions
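
In short: the commit maps Anthropic stop reasons (end_turn, stop_sequence, tool_use, pause_turn, max_tokens, refusal) onto google.genai FinishReason enum values and propagates them through both the non-streaming and streaming response paths, with unit tests covering the mapping table and the streaming case.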

src/google/adk/models/anthropic_llm.py (11 additions & 5 deletions)
@@ -134,10 +134,14 @@ def to_google_genai_finish_reason(
     anthropic_stop_reason: Optional[str],
 ) -> types.FinishReason:
   if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
-    return "STOP"
+    return types.FinishReason.STOP
+  if anthropic_stop_reason == "pause_turn":
+    return types.FinishReason.STOP
   if anthropic_stop_reason == "max_tokens":
-    return "MAX_TOKENS"
-  return "FINISH_REASON_UNSPECIFIED"
+    return types.FinishReason.MAX_TOKENS
+  if anthropic_stop_reason == "refusal":
+    return types.FinishReason.SAFETY
+  return types.FinishReason.FINISH_REASON_UNSPECIFIED
 
 
 def _is_image_part(part: types.Part) -> bool:
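
For a quick sanity check outside the test suite, the new mapping can be exercised directly. A minimal sketch, assuming google-adk at this commit and google-genai are importable:

from google.adk.models.anthropic_llm import to_google_genai_finish_reason
from google.genai import types

# Per the hunk above: pause_turn now coerces to STOP, refusal surfaces as
# SAFETY, and anything unrecognized (including None) stays UNSPECIFIED.
assert to_google_genai_finish_reason("pause_turn") == types.FinishReason.STOP
assert to_google_genai_finish_reason("refusal") == types.FinishReason.SAFETY
assert (
    to_google_genai_finish_reason(None)
    == types.FinishReason.FINISH_REASON_UNSPECIFIED
)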
@@ -343,8 +347,7 @@ def message_to_generate_content_response(
           message.usage.input_tokens + message.usage.output_tokens
       ),
     ),
-    # TODO: Deal with these later.
-    # finish_reason=to_google_genai_finish_reason(message.stop_reason),
+    finish_reason=to_google_genai_finish_reason(message.stop_reason),
 )
 
 
@@ -547,6 +550,7 @@ async def _generate_content_streaming(
   redacted_thinking_blocks: dict[int, str] = {}
   input_tokens = 0
   output_tokens = 0
+  finish_reason: types.FinishReason | None = None
 
   async for event in raw_stream:
     if event.type == "message_start":
@@ -603,6 +607,7 @@ async def _generate_content_streaming(
 
     elif event.type == "message_delta":
       output_tokens = event.usage.output_tokens
+      finish_reason = to_google_genai_finish_reason(event.delta.stop_reason)
 
   # Build the final aggregated response with all content.
   all_parts: list[types.Part] = []
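
For background, Anthropic's streaming protocol reports stop_reason on the message_delta event near the end of the stream rather than on individual content deltas, which is why the accumulator assigns finish_reason only in this branch. A stripped-down sketch of reading it from a raw stream:

async for event in raw_stream:
  if event.type == "message_delta":
    # stop_reason is e.g. "end_turn", "max_tokens", or "refusal";
    # usage.output_tokens is the cumulative output count at this point.
    finish_reason = to_google_genai_finish_reason(event.delta.stop_reason)
    output_tokens = event.usage.output_tokens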
@@ -644,6 +649,7 @@ async def _generate_content_streaming(
       candidates_token_count=output_tokens,
       total_token_count=input_tokens + output_tokens,
     ),
+    finish_reason=finish_reason,
     partial=False,
 )
 

tests/unittests/models/test_anthropic_llm.py (66 additions & 0 deletions)
@@ -28,6 +28,7 @@
 from google.adk.models.anthropic_llm import content_to_message_param
 from google.adk.models.anthropic_llm import function_declaration_to_tool_param
 from google.adk.models.anthropic_llm import part_to_message_block
+from google.adk.models.anthropic_llm import to_google_genai_finish_reason
 from google.adk.models.llm_request import LlmRequest
 from google.adk.models.llm_response import LlmResponse
 from google.genai import types
@@ -139,6 +140,23 @@ def test_supported_models():
   assert models[1] == r"claude-.*-4.*"
 
 
+@pytest.mark.parametrize(
+    ("stop_reason", "expected"),
+    [
+        ("end_turn", types.FinishReason.STOP),
+        ("stop_sequence", types.FinishReason.STOP),
+        ("tool_use", types.FinishReason.STOP),
+        ("pause_turn", types.FinishReason.STOP),
+        ("max_tokens", types.FinishReason.MAX_TOKENS),
+        ("refusal", types.FinishReason.SAFETY),
+        (None, types.FinishReason.FINISH_REASON_UNSPECIFIED),
+        ("unknown_reason", types.FinishReason.FINISH_REASON_UNSPECIFIED),
+    ],
+)
+def test_to_google_genai_finish_reason(stop_reason, expected):
+  assert to_google_genai_finish_reason(stop_reason) == expected
+
+
 function_declaration_test_cases = [
     (
         "function_with_no_parameters",
@@ -1350,6 +1368,54 @@ async def test_non_streaming_does_not_pass_stream_param():
   mock_client.messages.create.assert_called_once()
   _, kwargs = mock_client.messages.create.call_args
   assert "stream" not in kwargs
+  assert responses[0].finish_reason == types.FinishReason.STOP
+
+
+@pytest.mark.asyncio
+async def test_streaming_sets_finish_reason_from_message_delta():
+  llm = AnthropicLlm(model="claude-sonnet-4-20250514")
+
+  events = [
+      MagicMock(
+          type="message_start",
+          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
+      ),
+      MagicMock(
+          type="content_block_start",
+          index=0,
+          content_block=anthropic_types.TextBlock(text="", type="text"),
+      ),
+      MagicMock(
+          type="content_block_delta",
+          index=0,
+          delta=anthropic_types.TextDelta(text="Hello", type="text_delta"),
+      ),
+      MagicMock(type="content_block_stop", index=0),
+      MagicMock(
+          type="message_delta",
+          delta=MagicMock(stop_reason="max_tokens"),
+          usage=MagicMock(output_tokens=3),
+      ),
+      MagicMock(type="message_stop"),
+  ]
+
+  mock_client = MagicMock()
+  mock_client.messages.create = AsyncMock(
+      return_value=_make_mock_stream_events(events)
+  )
+
+  llm_request = LlmRequest(
+      model="claude-sonnet-4-20250514",
+      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
+      config=types.GenerateContentConfig(system_instruction="Test"),
+  )
+
+  with mock.patch.object(llm, "_anthropic_client", mock_client):
+    responses = [
+        r async for r in llm.generate_content_async(llm_request, stream=True)
+    ]
+
+  assert responses[-1].finish_reason == types.FinishReason.MAX_TOKENS
 
 
 def test_part_to_message_block_function_call_preserves_valid_id():
