Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
import docstring_parser
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticSerializationError
from pydantic_core import PydanticSerializationError, PydanticUndefined
from typing_extensions import override

from ..interrupt import InterruptException
Expand Down Expand Up @@ -159,12 +159,28 @@ def _extract_annotated_metadata(
)

# Determine the final description with a clear priority order
# Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback
# Priority: 1. Annotated string -> 2. Docstring -> 3. Field description -> 4. Fallback
final_description = description
if final_description is None:
final_description = self.param_descriptions.get(param_name) or f"Parameter {param_name}"
# Create FieldInfo object from scratch
final_field = Field(default=param_default, description=final_description)
field_description = param_default.description if isinstance(param_default, FieldInfo) else None
final_description = (
self.param_descriptions.get(param_name) or field_description or f"Parameter {param_name}"
)

# Create FieldInfo object from scratch, correctly handling Field() defaults.
# When the caller uses `param: T = Field(default_factory=list)`, param_default is a
# FieldInfo whose .default is PydanticUndefined. Passing a FieldInfo as `default=` to
# a new Field() triggers a PydanticJsonSchemaWarning because FieldInfo is not JSON
# serializable, so we unwrap it and forward the actual default or factory instead.
if isinstance(param_default, FieldInfo):
if param_default.default_factory is not None:
final_field = Field(default_factory=param_default.default_factory, description=final_description)
elif param_default.default is not PydanticUndefined:
final_field = Field(default=param_default.default, description=final_description)
else:
final_field = Field(description=final_description)
else:
final_field = Field(default=param_default, description=final_description)

return actual_type, final_field

Expand Down
41 changes: 41 additions & 0 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,3 +2101,44 @@ def my_tool(name: str, tag: str | None = None) -> str:
# Since tag is not required, anyOf should be simplified away
assert "anyOf" not in schema["properties"]["tag"]
assert schema["properties"]["tag"]["type"] == "string"


def test_tool_field_default_factory_no_warning(recwarn):
"""Field(default_factory=...) as a parameter default must not emit PydanticJsonSchemaWarning."""

@strands.tool
def my_tool(items: list[str] = Field(default_factory=list, description="items")) -> int: # noqa: B008
"""A tool."""
return len(items)

schema_warnings = [w for w in recwarn.list if "not JSON serializable" in str(w.message)]
assert schema_warnings == [], "PydanticJsonSchemaWarning should not be emitted for default_factory"


def test_tool_field_default_factory_description_used():
"""Description provided via Field(description=...) is included in the tool schema."""

@strands.tool
def my_tool(items: list[str] = Field(default_factory=list, description="the items list")) -> int: # noqa: B008
"""A tool."""
return len(items)

schema = my_tool.tool_spec["inputSchema"]["json"]
assert schema["properties"]["items"]["description"] == "the items list"
assert "items" not in schema.get("required", [])


def test_tool_field_default_value_no_warning(recwarn):
"""Field(default=...) as a parameter default must not emit PydanticJsonSchemaWarning."""

@strands.tool
def my_tool(count: int = Field(default=5, description="the count")) -> int: # noqa: B008
"""A tool."""
return count

schema_warnings = [w for w in recwarn.list if "not JSON serializable" in str(w.message)]
assert schema_warnings == [], "PydanticJsonSchemaWarning should not be emitted for Field(default=...)"

schema = my_tool.tool_spec["inputSchema"]["json"]
assert schema["properties"]["count"]["description"] == "the count"
assert "count" not in schema.get("required", [])