diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 54fdaaf00..4d7b6870f 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -20,6 +20,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolChoiceToolDict, ToolSpec +from ._strict_schema import ensure_strict_json_schema from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -231,7 +232,12 @@ def format_request( { "name": tool_spec["name"], "description": tool_spec["description"], - "input_schema": tool_spec["inputSchema"]["json"], + "input_schema": ( + ensure_strict_json_schema(tool_spec["inputSchema"]["json"]) + if tool_spec.get("strict") + else tool_spec["inputSchema"]["json"] + ), + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), } for tool_spec in tool_specs or [] ], diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d535bbc51..6efb154bb 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -270,10 +270,14 @@ def _format_request( "description": tool_spec["description"], "inputSchema": ( {"json": ensure_strict_json_schema(tool_spec["inputSchema"]["json"])} - if self.config.get("strict_tools") + if tool_spec.get("strict", self.config.get("strict_tools")) else tool_spec["inputSchema"] ), - **({"strict": True} if self.config.get("strict_tools") else {}), + **( + {"strict": True} + if tool_spec.get("strict", self.config.get("strict_tools")) + else {} + ), } } for tool_spec in tool_specs diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index c4be7d360..082c9bffd 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -21,6 +21,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse +from ._strict_schema import ensure_strict_json_schema from ._validation import _has_location_source, validate_config_keys from .model import BaseModelConfig, Model @@ -481,7 +482,12 @@ def format_request( "function": { "name": tool_spec["name"], "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], + "parameters": ( + ensure_strict_json_schema(tool_spec["inputSchema"]["json"], require_all_properties=True) + if tool_spec.get("strict") + else tool_spec["inputSchema"]["json"] + ), + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), }, } for tool_spec in tool_specs or [] diff --git a/src/strands/models/openai_responses.py b/src/strands/models/openai_responses.py index 73a889aad..9ce1fb20a 100644 --- a/src/strands/models/openai_responses.py +++ b/src/strands/models/openai_responses.py @@ -58,6 +58,7 @@ from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException # noqa: E402 from ..types.streaming import StreamEvent # noqa: E402 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse # noqa: E402 +from ._strict_schema import ensure_strict_json_schema # noqa: E402 from ._validation import validate_config_keys # noqa: E402 from .model import BaseModelConfig, Model # noqa: E402 @@ -516,7 +517,12 @@ def _format_request( "type": "function", "name": tool_spec["name"], "description": tool_spec.get("description", ""), - "parameters": tool_spec["inputSchema"]["json"], + "parameters": ( + ensure_strict_json_schema(tool_spec["inputSchema"]["json"], require_all_properties=True) + if tool_spec.get("strict") + else tool_spec["inputSchema"]["json"] + ), + **({"strict": tool_spec["strict"]} if "strict" in tool_spec else {}), } for tool_spec in tool_specs ) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 9207df9b8..b1c30c821 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -725,6 +725,7 @@ def tool( inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + strict: bool | None = None, ) -> Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: ... # Suppressing the type error because we want callers to be able to use both `tool` and `tool()` at the # call site, but the actual implementation handles that and it's not representable via the type-system @@ -734,6 +735,7 @@ def tool( # type: ignore inputSchema: JSONSchema | None = None, name: str | None = None, context: bool | str = False, + strict: bool | None = None, ) -> DecoratedFunctionTool[P, R] | Callable[[Callable[P, R]], DecoratedFunctionTool[P, R]]: """Decorator that transforms a Python function into a Strands tool. @@ -762,6 +764,10 @@ def tool( # type: ignore context: When provided, places an object in the designated parameter. If True, the param name defaults to 'tool_context', or if an override is needed, set context equal to a string to designate the param name. + strict: Optional Boolean that ensures the model will only output tool calls containing parameters + that perfectly match the defined input schema. Note: When using strict mode, optional parameters + must be explicitly typed as nullable (e.g., `Optional[str]`), otherwise the model will be forced + to generate a value for them. Returns: An AgentTool that also mimics the original function when invoked @@ -816,6 +822,8 @@ def decorator(f: T) -> "DecoratedFunctionTool[P, R]": tool_spec["description"] = description if inputSchema is not None: tool_spec["inputSchema"] = inputSchema + if strict is not None: + tool_spec["strict"] = strict tool_name = tool_spec.get("name", f.__name__) diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 088c83bdb..5107a0c5b 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -30,12 +30,19 @@ class ToolSpec(TypedDict): outputSchema: Optional JSON Schema defining the expected output format. Note: Not all model providers support this field. Providers that don't support it should filter it out before sending to their API. + strict: Optional Boolean that ensures the model will only output tool calls + containing parameters that perfectly match the defined input schema. + Note: When using strict mode, optional parameters must be explicitly typed + as nullable (e.g., `Optional[str]`), otherwise the model will be forced + to generate a value for them. Not all model providers support this field. + Providers that don't support it should filter it out before sending to their API. """ description: str inputSchema: JSONSchema name: str outputSchema: NotRequired[JSONSchema] + strict: NotRequired[bool] class Tool(TypedDict): diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 8e004dbb7..24055e30d 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -459,6 +459,41 @@ def test_format_request_tool_choice_auto(model, messages, model_id, max_tokens): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, model_id, max_tokens): + strict_tool_spec = { + "description": "description", + "name": "name", + "inputSchema": { + "json": { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + }, + "strict": True, + } + tru_request = model.format_request(messages, tool_specs=[strict_tool_spec]) + tool_in_request = tru_request["tools"][0] + + assert tool_in_request["strict"] is True + assert tool_in_request["input_schema"]["additionalProperties"] is False + + +def test_format_request_tool_specs_with_strict_false(model, messages, model_id, max_tokens): + tool_specs = [ + {"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}, "strict": False} + ] + tru_request = model.format_request(messages, tool_specs) + + assert tru_request["tools"][0]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, model_id, max_tokens): + tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] + tru_request = model.format_request(messages, tool_specs) + + assert "strict" not in tru_request["tools"][0] + + def test_format_request_tool_choice_any(model, messages, model_id, max_tokens): tool_specs = [{"description": "test tool", "name": "test_tool", "inputSchema": {"json": {"key": "value"}}}] tool_choice = {"any": {}} diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index a80ca091e..cabe73887 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -675,6 +675,68 @@ def test_format_request_strict_tools_applies_to_all_tools(bedrock_client, model_ assert tool["toolSpec"]["inputSchema"]["json"]["additionalProperties"] is False +def test_format_request_strict_tools_overridden_by_tool_spec(bedrock_client, model_id, messages): + tool_specs = [ + {"name": "strict_tool", "description": "Tool A", "inputSchema": {"json": {"type": "object", "properties": {}}}}, + { + "name": "non_strict_tool", + "description": "Tool B", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "strict": False, + }, + ] + model = BedrockModel(model_id=model_id, strict_tools=True) + request = model._format_request(messages, tool_specs=tool_specs) + + tool_a = request["toolConfig"]["tools"][0]["toolSpec"] + tool_b = request["toolConfig"]["tools"][1]["toolSpec"] + + assert tool_a.get("strict") is True + assert tool_a["inputSchema"]["json"]["additionalProperties"] is False + + assert "strict" not in tool_b + assert "additionalProperties" not in tool_b["inputSchema"]["json"] + + +def test_format_request_tool_specs_with_strict(model, messages, model_id): + strict_tool_spec = { + "description": "description", + "name": "name", + "inputSchema": {"json": {"type": "object", "properties": {"key": {"type": "string"}}}}, + "strict": True, + } + tru_request = model._format_request(messages, tool_specs=[strict_tool_spec]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert tool_in_request["strict"] is True + assert tool_in_request["inputSchema"]["json"]["additionalProperties"] is False + + +def test_format_request_tool_specs_with_strict_false(model, messages, model_id): + strict_false_tool_spec = { + "description": "description", + "name": "name", + "inputSchema": {"json": {"type": "object", "properties": {"key": {"type": "string"}}}}, + "strict": False, + } + tru_request = model._format_request(messages, tool_specs=[strict_false_tool_spec]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_in_request + + +def test_format_request_tool_specs_without_strict(model, messages, model_id): + tool_spec_no_strict = { + "description": "description", + "name": "name", + "inputSchema": {"json": {"type": "object", "properties": {"key": {"type": "string"}}}}, + } + tru_request = model._format_request(messages, tool_specs=[tool_spec_no_strict]) + tool_in_request = tru_request["toolConfig"]["tools"][0]["toolSpec"] + + assert "strict" not in tool_in_request + + def test_format_request_tool_choice_auto(model, messages, model_id, tool_spec): tool_choice = {"auto": {}} tru_request = model._format_request(messages, [tool_spec], tool_choice=tool_choice) diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 94e4caa3f..68551a3a1 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -636,6 +636,51 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, system_prompt): + strict_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}, "optional_input": {"type": "string"}}, + "required": ["input"], + } + }, + "strict": True, + } + ] + tru_request = model.format_request(messages, strict_tool_specs, system_prompt) + + tool_function = tru_request["tools"][0]["function"] + assert tool_function["strict"] is True + assert tool_function["parameters"]["additionalProperties"] is False + assert set(tool_function["parameters"]["required"]) == {"input", "optional_input"} + + +def test_format_request_tool_specs_with_strict_false(model, messages, system_prompt): + strict_false_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": False, + } + ] + tru_request = model.format_request(messages, strict_false_tool_specs, system_prompt) + + assert tru_request["tools"][0]["function"]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, tool_specs, system_prompt): + tru_request = model.format_request(messages, tool_specs, system_prompt) + + assert "strict" not in tru_request["tools"][0]["function"] + + def test_format_request_with_tool_choice_auto(model, messages, tool_specs, system_prompt): tool_choice = {"auto": {}} tru_request = model.format_request(messages, tool_specs, system_prompt, tool_choice) diff --git a/tests/strands/models/test_openai_responses.py b/tests/strands/models/test_openai_responses.py index 88cbee326..8fdb7bea1 100644 --- a/tests/strands/models/test_openai_responses.py +++ b/tests/strands/models/test_openai_responses.py @@ -349,6 +349,60 @@ def test_format_request(model, messages, tool_specs, system_prompt): assert tru_request == exp_request +def test_format_request_tool_specs_with_strict(model, messages, system_prompt): + strict_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"input": {"type": "string"}, "optional_input": {"type": "string"}}, + "required": ["input"], + } + }, + "strict": True, + } + ] + tru_request = model._format_request(messages, strict_tool_specs, system_prompt) + + tool = tru_request["tools"][0] + assert tool["strict"] is True + assert tool["parameters"]["additionalProperties"] is False + assert set(tool["parameters"]["required"]) == {"input", "optional_input"} + + +def test_format_request_tool_specs_with_strict_false(model, messages, system_prompt): + strict_false_tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + "strict": False, + } + ] + tru_request = model._format_request(messages, strict_false_tool_specs, system_prompt) + + assert tru_request["tools"][0]["strict"] is False + + +def test_format_request_tool_specs_without_strict(model, messages, system_prompt): + tool_specs = [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": {"type": "object", "properties": {"input": {"type": "string"}}, "required": ["input"]} + }, + } + ] + tru_request = model._format_request(messages, tool_specs, system_prompt) + + assert "strict" not in tru_request["tools"][0] + + @pytest.mark.parametrize( ("event", "exp_chunk"), [ diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index cc1158983..579bc0a5b 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -90,6 +90,33 @@ def test_tool_spec(identity_tool): assert tru_spec == exp_spec +def test_tool_spec_with_strict_true(): + @strands.tool(strict=True) + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert my_tool.tool_spec["strict"] is True + + +def test_tool_spec_with_strict_false(): + @strands.tool(strict=False) + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert my_tool.tool_spec["strict"] is False + + +def test_tool_spec_without_strict(): + @strands.tool + def my_tool(param: str) -> str: + """A tool.""" + return param + + assert "strict" not in my_tool.tool_spec + + @pytest.mark.parametrize("identity_tool", ["identity_invoke", "identity_invoke_async"], indirect=True) def test_tool_type(identity_tool): tru_type = identity_tool.tool_type diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index a5eba45b9..a554be14e 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -221,3 +221,17 @@ async def test_count_tokens_with_tools_greater_than_without(self, model, message without = await model.count_tokens(messages=messages) with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") assert with_tools > without + + +def test_strict_tool_integration(model): + """Test that a strict tool invocation is accepted by the Anthropic API without a 400 validation error.""" + + @strands.tool(strict=True) + def strict_tool(text: str) -> str: + """A strict tool for testing.""" + return f"Echo: {text}" + + agent = Agent(model=model, tools=[strict_tool]) + # The API should accept the request and respond normally without throwing a 400 validation error + result = agent("Call the strict tool with the text 'hello'") + assert result.message is not None diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index bef526427..4e1dccec6 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -438,3 +438,17 @@ async def test_count_tokens_with_tools_greater_than_without(self, model, message without = await model.count_tokens(messages=messages) with_tools = await model.count_tokens(messages=messages, tool_specs=tool_specs, system_prompt="Be helpful.") assert with_tools > without + + +def test_strict_tool_integration(model): + """Test that a strict tool invocation is accepted by the OpenAI API without a 400 validation error.""" + + @strands.tool(strict=True) + def strict_tool(text: str) -> str: + """A strict tool for testing.""" + return f"Echo: {text}" + + agent = Agent(model=model, tools=[strict_tool]) + # The API should accept the request and respond normally without throwing a 400 validation error + result = agent("Call the strict tool with the text 'hello'") + assert result.message is not None