Skip to content

Commit 697f686

Browse files
authored
fix: None value handling of flattened generation kwargs for AmazonBedrockChatGenerator (#2752)
* fix: None value handling of flattened generation kwargs for AmazonBedrockChatGenerator * fmt * apply feedback
1 parent acaf0e4 commit 697f686

2 files changed

Lines changed: 72 additions & 21 deletions

File tree

integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -423,27 +423,31 @@ def _prepare_request_params(
423423

424424
def _resolve_flattened_generation_kwargs(self, generation_kwargs: dict[str, Any]) -> dict[str, Any]:
425425
generation_kwargs = generation_kwargs.copy()
426-
if "disable_parallel_tool_use" in generation_kwargs:
427-
disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use")
428-
tool_choice = generation_kwargs.setdefault("tool_choice", {})
429-
tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
430426

431-
if "parallel_tool_use" in generation_kwargs:
432-
parallel_tool_use = generation_kwargs.pop("parallel_tool_use")
427+
disable_parallel_tool_use = generation_kwargs.pop("disable_parallel_tool_use", None)
428+
parallel_tool_use = generation_kwargs.pop("parallel_tool_use", None)
429+
430+
if disable_parallel_tool_use is not None and parallel_tool_use is not None:
431+
msg = "Cannot set both disable_parallel_tool_use and parallel_tool_use"
432+
raise ValueError(msg)
433+
elif parallel_tool_use is not None:
433434
disable_parallel_tool_use = not parallel_tool_use
435+
436+
if disable_parallel_tool_use is not None:
434437
tool_choice = generation_kwargs.setdefault("tool_choice", {})
435438
tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use
439+
tool_choice.setdefault("type", "auto") # default value
436440

437-
if "tool_choice_type" in generation_kwargs:
438-
tool_choice_type = generation_kwargs.pop("tool_choice_type")
441+
tool_choice_type = generation_kwargs.pop("tool_choice_type", None)
442+
if tool_choice_type is not None:
439443
tool_choice = generation_kwargs.setdefault("tool_choice", {})
440444
tool_choice["type"] = tool_choice_type
441445

442-
if "thinking_budget_tokens" in generation_kwargs:
443-
thinking_budget_tokens = generation_kwargs.pop("thinking_budget_tokens")
446+
thinking_budget_tokens = generation_kwargs.pop("thinking_budget_tokens", None)
447+
if thinking_budget_tokens is not None:
444448
thinking = generation_kwargs.setdefault("thinking", {})
445449
thinking["budget_tokens"] = thinking_budget_tokens
446-
thinking["type"] = "enabled"
450+
thinking.setdefault("type", "enabled")
447451

448452
return generation_kwargs
449453

integrations/amazon_bedrock/tests/test_chat_generator.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -409,21 +409,68 @@ def tool_fn(city: str) -> str:
409409
]
410410
}
411411

412-
def test_prepare_request_params_with_flattened_generation_kwargs(self, mock_boto3_session, set_env_variables):
412+
@pytest.mark.parametrize(
413+
"generation_kwargs,additional_model_request_fields",
414+
[
415+
(
416+
{
417+
"parallel_tool_use": False,
418+
"tool_choice_type": "any",
419+
"thinking_budget_tokens": 1024,
420+
},
421+
{
422+
"tool_choice": {"disable_parallel_tool_use": True, "type": "any"},
423+
"thinking": {"budget_tokens": 1024, "type": "enabled"},
424+
},
425+
),
426+
(
427+
{
428+
"parallel_tool_use": True,
429+
"tool_choice_type": "all",
430+
},
431+
{
432+
"tool_choice": {"disable_parallel_tool_use": False, "type": "all"},
433+
},
434+
),
435+
(
436+
{
437+
"parallel_tool_use": True,
438+
},
439+
{
440+
"tool_choice": {"disable_parallel_tool_use": False, "type": "auto"},
441+
},
442+
),
443+
(
444+
{
445+
"disable_parallel_tool_use": True,
446+
},
447+
{
448+
"tool_choice": {"disable_parallel_tool_use": True, "type": "auto"},
449+
},
450+
),
451+
(
452+
{
453+
"thinking_budget_tokens": None,
454+
"parallel_tool_use": None,
455+
"tool_choice_type": None,
456+
},
457+
{},
458+
),
459+
],
460+
)
461+
def test_prepare_request_params_with_flattened_generation_kwargs(
462+
self, mock_boto3_session, set_env_variables, generation_kwargs, additional_model_request_fields
463+
):
413464
generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")
414465
request_params, _ = generator._prepare_request_params(
415466
messages=[ChatMessage.from_user("What's the capital of France?")],
416-
generation_kwargs={
417-
"parallel_tool_use": False,
418-
"tool_choice_type": "any",
419-
"thinking_budget_tokens": 1024,
420-
},
467+
generation_kwargs=generation_kwargs,
421468
)
422469

423-
assert request_params["additionalModelRequestFields"] == {
424-
"tool_choice": {"disable_parallel_tool_use": True, "type": "any"},
425-
"thinking": {"budget_tokens": 1024, "type": "enabled"},
426-
}
470+
if not additional_model_request_fields:
471+
assert "additionalModelRequestFields" not in request_params
472+
else:
473+
assert request_params["additionalModelRequestFields"] == additional_model_request_fields
427474

428475

429476
# In the CI, those tests are skipped if AWS Authentication fails

0 commit comments

Comments
 (0)