@@ -409,21 +409,68 @@ def tool_fn(city: str) -> str:
409409 ]
410410 }
411411
412- def test_prepare_request_params_with_flattened_generation_kwargs (self , mock_boto3_session , set_env_variables ):
412+ @pytest .mark .parametrize (
413+ "generation_kwargs,additional_model_request_fields" ,
414+ [
415+ (
416+ {
417+ "parallel_tool_use" : False ,
418+ "tool_choice_type" : "any" ,
419+ "thinking_budget_tokens" : 1024 ,
420+ },
421+ {
422+ "tool_choice" : {"disable_parallel_tool_use" : True , "type" : "any" },
423+ "thinking" : {"budget_tokens" : 1024 , "type" : "enabled" },
424+ },
425+ ),
426+ (
427+ {
428+ "parallel_tool_use" : True ,
429+ "tool_choice_type" : "all" ,
430+ },
431+ {
432+ "tool_choice" : {"disable_parallel_tool_use" : False , "type" : "all" },
433+ },
434+ ),
435+ (
436+ {
437+ "parallel_tool_use" : True ,
438+ },
439+ {
440+ "tool_choice" : {"disable_parallel_tool_use" : False , "type" : "auto" },
441+ },
442+ ),
443+ (
444+ {
445+ "disable_parallel_tool_use" : True ,
446+ },
447+ {
448+ "tool_choice" : {"disable_parallel_tool_use" : True , "type" : "auto" },
449+ },
450+ ),
451+ (
452+ {
453+ "thinking_budget_tokens" : None ,
454+ "parallel_tool_use" : None ,
455+ "tool_choice_type" : None ,
456+ },
457+ {},
458+ ),
459+ ],
460+ )
461+ def test_prepare_request_params_with_flattened_generation_kwargs (
462+ self , mock_boto3_session , set_env_variables , generation_kwargs , additional_model_request_fields
463+ ):
413464 generator = AmazonBedrockChatGenerator (model = "anthropic.claude-3-5-sonnet-20240620-v1:0" )
414465 request_params , _ = generator ._prepare_request_params (
415466 messages = [ChatMessage .from_user ("What's the capital of France?" )],
416- generation_kwargs = {
417- "parallel_tool_use" : False ,
418- "tool_choice_type" : "any" ,
419- "thinking_budget_tokens" : 1024 ,
420- },
467+ generation_kwargs = generation_kwargs ,
421468 )
422469
423- assert request_params [ "additionalModelRequestFields" ] == {
424- "tool_choice" : { "disable_parallel_tool_use" : True , "type" : "any" },
425- "thinking" : { "budget_tokens" : 1024 , "type" : "enabled" },
426- }
470+ if not additional_model_request_fields :
471+ assert "additionalModelRequestFields" not in request_params
472+ else :
473+ assert request_params [ "additionalModelRequestFields" ] == additional_model_request_fields
427474
428475
429476# In the CI, those tests are skipped if AWS Authentication fails
0 commit comments