|
28 | 28 | from google.adk.models.anthropic_llm import content_to_message_param |
29 | 29 | from google.adk.models.anthropic_llm import function_declaration_to_tool_param |
30 | 30 | from google.adk.models.anthropic_llm import part_to_message_block |
| 31 | +from google.adk.models.anthropic_llm import to_google_genai_finish_reason |
31 | 32 | from google.adk.models.llm_request import LlmRequest |
32 | 33 | from google.adk.models.llm_response import LlmResponse |
33 | 34 | from google.genai import types |
@@ -139,6 +140,23 @@ def test_supported_models(): |
139 | 140 | assert models[1] == r"claude-.*-4.*" |
140 | 141 |
|
141 | 142 |
|
# Table of Anthropic stop_reason values and the google.genai FinishReason
# each one must map to. None and unrecognized strings both fall back to
# FINISH_REASON_UNSPECIFIED.
_STOP_REASON_TO_FINISH_REASON = [
    ("end_turn", types.FinishReason.STOP),
    ("stop_sequence", types.FinishReason.STOP),
    ("tool_use", types.FinishReason.STOP),
    ("pause_turn", types.FinishReason.STOP),
    ("max_tokens", types.FinishReason.MAX_TOKENS),
    ("refusal", types.FinishReason.SAFETY),
    (None, types.FinishReason.FINISH_REASON_UNSPECIFIED),
    ("unknown_reason", types.FinishReason.FINISH_REASON_UNSPECIFIED),
]


@pytest.mark.parametrize(
    ("stop_reason", "expected"), _STOP_REASON_TO_FINISH_REASON
)
def test_to_google_genai_finish_reason(stop_reason, expected):
  """Each Anthropic stop_reason translates to the expected FinishReason."""
  actual = to_google_genai_finish_reason(stop_reason)
  assert actual == expected
| 158 | + |
| 159 | + |
142 | 160 | function_declaration_test_cases = [ |
143 | 161 | ( |
144 | 162 | "function_with_no_parameters", |
@@ -1350,6 +1368,54 @@ async def test_non_streaming_does_not_pass_stream_param(): |
1350 | 1368 | mock_client.messages.create.assert_called_once() |
1351 | 1369 | _, kwargs = mock_client.messages.create.call_args |
1352 | 1370 | assert "stream" not in kwargs |
| 1371 | + assert responses[0].finish_reason == types.FinishReason.STOP |
| 1372 | + |
| 1373 | + |
@pytest.mark.asyncio
async def test_streaming_sets_finish_reason_from_message_delta():
  """Streaming surfaces the stop_reason carried by the message_delta event.

  The mocked sequence mirrors Anthropic's streaming protocol
  (message_start -> content_block_start/delta/stop -> message_delta ->
  message_stop); the final LlmResponse must report MAX_TOKENS because the
  message_delta event carries stop_reason="max_tokens".
  """
  llm = AnthropicLlm(model="claude-sonnet-4-20250514")

  stream_events = [
      MagicMock(
          type="message_start",
          message=MagicMock(usage=MagicMock(input_tokens=5, output_tokens=0)),
      ),
      MagicMock(
          type="content_block_start",
          index=0,
          content_block=anthropic_types.TextBlock(text="", type="text"),
      ),
      MagicMock(
          type="content_block_delta",
          index=0,
          delta=anthropic_types.TextDelta(text="Hello", type="text_delta"),
      ),
      MagicMock(type="content_block_stop", index=0),
      MagicMock(
          type="message_delta",
          delta=MagicMock(stop_reason="max_tokens"),
          usage=MagicMock(output_tokens=3),
      ),
      MagicMock(type="message_stop"),
  ]

  fake_client = MagicMock()
  fake_client.messages.create = AsyncMock(
      return_value=_make_mock_stream_events(stream_events)
  )

  request = LlmRequest(
      model="claude-sonnet-4-20250514",
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
      config=types.GenerateContentConfig(system_instruction="Test"),
  )

  with mock.patch.object(llm, "_anthropic_client", fake_client):
    collected = []
    async for response in llm.generate_content_async(request, stream=True):
      collected.append(response)

  assert collected[-1].finish_reason == types.FinishReason.MAX_TOKENS
1353 | 1419 |
|
1354 | 1420 |
|
1355 | 1421 | def test_part_to_message_block_function_call_preserves_valid_id(): |
|
0 commit comments