Skip to content

Commit 781d57a

Browse files
committed
fix(models): forward Anthropic generation config
1 parent 60b9073 commit 781d57a

2 files changed

Lines changed: 124 additions & 0 deletions

File tree

src/google/adk/models/anthropic_llm.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -377,6 +377,22 @@ def _resolve_model_name(self, model: Optional[str]) -> str:
377377
return match.group(1)
378378
return model
379379

380+
def _get_generation_kwargs(
381+
self, llm_request: LlmRequest
382+
) -> dict[str, Any]:
383+
generation_kwargs: dict[str, Any] = {}
384+
385+
if llm_request.config.temperature is not None:
386+
generation_kwargs["temperature"] = llm_request.config.temperature
387+
if llm_request.config.top_p is not None:
388+
generation_kwargs["top_p"] = llm_request.config.top_p
389+
if llm_request.config.top_k is not None:
390+
generation_kwargs["top_k"] = llm_request.config.top_k
391+
if llm_request.config.stop_sequences:
392+
generation_kwargs["stop_sequences"] = llm_request.config.stop_sequences
393+
394+
return generation_kwargs
395+
380396
@override
381397
async def generate_content_async(
382398
self, llm_request: LlmRequest, stream: bool = False
@@ -401,6 +417,7 @@ async def generate_content_async(
401417
if llm_request.tools_dict
402418
else NOT_GIVEN
403419
)
420+
generation_kwargs = self._get_generation_kwargs(llm_request)
404421

405422
if not stream:
406423
message = await self._anthropic_client.messages.create(
@@ -410,6 +427,7 @@ async def generate_content_async(
410427
tools=tools,
411428
tool_choice=tool_choice,
412429
max_tokens=self.max_tokens,
430+
**generation_kwargs,
413431
)
414432
yield message_to_generate_content_response(message)
415433
else:
@@ -439,6 +457,7 @@ async def _generate_content_streaming(
439457
tool_choice=tool_choice,
440458
max_tokens=self.max_tokens,
441459
stream=True,
460+
**self._get_generation_kwargs(llm_request),
442461
)
443462

444463
# Track content blocks being built during streaming.

tests/unittests/models/test_anthropic_llm.py

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1350,3 +1350,108 @@ async def test_non_streaming_does_not_pass_stream_param():
13501350
mock_client.messages.create.assert_called_once()
13511351
_, kwargs = mock_client.messages.create.call_args
13521352
assert "stream" not in kwargs
1353+
1354+
1355+
@pytest.mark.asyncio
async def test_non_streaming_forwards_generation_config_kwargs():
  """Non-streaming calls pass the config's generation options to create()."""
  model_name = "claude-sonnet-4-20250514"
  llm = AnthropicLlm(model=model_name)

  # The options set on the request config; each must show up unchanged in
  # the kwargs of the messages.create call.
  expected_kwargs = {
      "temperature": 0.0,
      "top_p": 0.8,
      "top_k": 12,
      "stop_sequences": ["DONE"],
  }

  usage = anthropic_types.Usage(
      input_tokens=5,
      output_tokens=2,
      cache_creation_input_tokens=0,
      cache_read_input_tokens=0,
      server_tool_use=None,
      service_tier=None,
  )
  reply_block = anthropic_types.TextBlock(
      text="Hello!", type="text", citations=None
  )
  canned_reply = anthropic_types.Message(
      id="msg_test",
      content=[reply_block],
      model=model_name,
      role="assistant",
      stop_reason="end_turn",
      stop_sequence=None,
      type="message",
      usage=usage,
  )

  fake_client = MagicMock()
  fake_client.messages.create = AsyncMock(return_value=canned_reply)

  request = LlmRequest(
      model=model_name,
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
      config=types.GenerateContentConfig(
          system_instruction="Test",
          **expected_kwargs,
      ),
  )

  with mock.patch.object(llm, "_anthropic_client", fake_client):
    # Drain the generator; only the recorded call arguments matter here.
    async for _ in llm.generate_content_async(request, stream=False):
      pass

  _, call_kwargs = fake_client.messages.create.call_args
  for name, value in expected_kwargs.items():
    assert call_kwargs[name] == value
1402+
1403+
1404+
@pytest.mark.asyncio
async def test_streaming_forwards_generation_config_kwargs():
  """Streaming calls pass the config's generation options to create()."""
  model_name = "claude-sonnet-4-20250514"
  llm = AnthropicLlm(model=model_name)

  # The options set on the request config; each must show up unchanged in
  # the kwargs of the messages.create call.
  expected_kwargs = {
      "temperature": 0.2,
      "top_p": 0.7,
      "top_k": 8,
      "stop_sequences": ["STOP_HERE"],
  }

  # A minimal single-text-block stream, from message_start to message_stop.
  start_usage = MagicMock(input_tokens=5, output_tokens=0)
  empty_text_block = anthropic_types.TextBlock(text="", type="text")
  text_delta = anthropic_types.TextDelta(text="Hi", type="text_delta")
  stream_events = [
      MagicMock(type="message_start", message=MagicMock(usage=start_usage)),
      MagicMock(
          type="content_block_start",
          index=0,
          content_block=empty_text_block,
      ),
      MagicMock(type="content_block_delta", index=0, delta=text_delta),
      MagicMock(type="content_block_stop", index=0),
      MagicMock(
          type="message_delta",
          delta=MagicMock(stop_reason="end_turn"),
          usage=MagicMock(output_tokens=1),
      ),
      MagicMock(type="message_stop"),
  ]

  fake_client = MagicMock()
  fake_client.messages.create = AsyncMock(
      return_value=_make_mock_stream_events(stream_events)
  )

  request = LlmRequest(
      model=model_name,
      contents=[Content(role="user", parts=[Part.from_text(text="Hi")])],
      config=types.GenerateContentConfig(
          system_instruction="Test",
          **expected_kwargs,
      ),
  )

  with mock.patch.object(llm, "_anthropic_client", fake_client):
    # Drain the generator; only the recorded call arguments matter here.
    async for _ in llm.generate_content_async(request, stream=True):
      pass

  _, call_kwargs = fake_client.messages.create.call_args
  assert call_kwargs["stream"] is True
  for name, value in expected_kwargs.items():
    assert call_kwargs[name] == value

0 commit comments

Comments
 (0)