Skip to content

Commit 12fd856

Browse files
authored
feat: serialize tool results as JSON when possible (#1752)
1 parent 32caa89 commit 12fd856

2 files changed

Lines changed: 149 additions & 11 deletions

File tree

src/strands/tools/decorator.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
4343
import asyncio
4444
import functools
4545
import inspect
46+
import json
4647
import logging
4748
from collections.abc import Callable
4849
from typing import (
@@ -61,6 +62,7 @@ def my_tool(param1: str, param2: int = 42) -> dict:
6162
import docstring_parser
6263
from pydantic import BaseModel, Field, create_model
6364
from pydantic.fields import FieldInfo
65+
from pydantic_core import PydanticSerializationError
6466
from typing_extensions import override
6567

6668
from ..interrupt import InterruptException
@@ -644,12 +646,25 @@ def _wrap_tool_result(self, tool_use_d: str, result: Any, exception: Exception |
644646
return ToolResultEvent(cast(ToolResult, result), exception=exception)
645647
else:
646648
# Wrap any other return value in the standard format
647-
# Always include at least one content item for consistency
649+
# Serialize to JSON for consistent, parseable output (except strings)
650+
if isinstance(result, str):
651+
text = result
652+
elif isinstance(result, BaseModel):
653+
try:
654+
text = result.model_dump_json()
655+
except PydanticSerializationError:
656+
text = str(result)
657+
else:
658+
try:
659+
text = json.dumps(result)
660+
except (TypeError, ValueError):
661+
text = str(result)
662+
648663
return ToolResultEvent(
649664
{
650665
"toolUseId": tool_use_d,
651666
"status": "success",
652-
"content": [{"text": str(result)}],
667+
"content": [{"text": text}],
653668
},
654669
exception=exception,
655670
)

tests/strands/tools/test_decorator.py

Lines changed: 132 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ def identity(a: int, agent: dict = None):
136136

137137
tru_events = await alist(stream)
138138
exp_events = [
139-
ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]})
139+
ToolResultEvent({"toolUseId": "unknown", "status": "success", "content": [{"text": '[2, {"state": 1}]'}]})
140140
]
141141
assert tru_events == exp_events
142142

@@ -595,12 +595,12 @@ def none_return_tool(param: str) -> None:
595595
assert result["tool_result"]["status"] == "success"
596596
assert result["tool_result"]["content"][0]["text"] == "Result: test"
597597

598-
# Test None return - should still create valid ToolResult with "None" text
598+
# Test None return - should still create valid ToolResult with "null"
599599
stream = none_return_tool.stream(tool_use, {})
600600

601601
result = (await alist(stream))[-1]
602602
assert result["tool_result"]["status"] == "success"
603-
assert result["tool_result"]["content"][0]["text"] == "None"
603+
assert result["tool_result"]["content"][0]["text"] == "null"
604604

605605

606606
@pytest.mark.asyncio
@@ -861,7 +861,7 @@ def int_return_tool(param: str) -> int:
861861

862862
result = (await alist(stream))[-1]
863863
assert result["tool_result"]["status"] == "success"
864-
assert result["tool_result"]["content"][0]["text"] == "None"
864+
assert result["tool_result"]["content"][0]["text"] == "null"
865865

866866
# Define tool with Union return type
867867
@strands.tool
@@ -884,10 +884,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None:
884884

885885
result = (await alist(stream))[-1]
886886
assert result["tool_result"]["status"] == "success"
887-
assert (
888-
"{'key': 'value'}" in result["tool_result"]["content"][0]["text"]
889-
or '{"key": "value"}' in result["tool_result"]["content"][0]["text"]
890-
)
887+
assert result["tool_result"]["content"][0]["text"] == '{"key": "value"}'
891888

892889
tool_use = {"toolUseId": "test-id", "input": {"param": "str"}}
893890
stream = union_return_tool.stream(tool_use, {})
@@ -901,7 +898,7 @@ def union_return_tool(param: str) -> dict[str, Any] | str | None:
901898

902899
result = (await alist(stream))[-1]
903900
assert result["tool_result"]["status"] == "success"
904-
assert result["tool_result"]["content"][0]["text"] == "None"
901+
assert result["tool_result"]["content"][0]["text"] == "null"
905902

906903

907904
@pytest.mark.asyncio
@@ -992,6 +989,132 @@ def custom_result_tool(param: str) -> dict[str, Any]:
992989
assert result["tool_result"]["content"][1]["type"] == "markdown"
993990

994991

992+
@pytest.mark.asyncio
993+
async def test_tool_result_json_serialization_dict(alist):
994+
"""Test that dict results are serialized as JSON."""
995+
996+
@strands.tool
997+
def dict_tool() -> dict:
998+
"""Returns a dict."""
999+
return {"key": "value", "number": 42}
1000+
1001+
tool_use = {"toolUseId": "test-id", "input": {}}
1002+
stream = dict_tool.stream(tool_use, {})
1003+
1004+
result = (await alist(stream))[-1]
1005+
text = result["tool_result"]["content"][0]["text"]
1006+
1007+
assert text == '{"key": "value", "number": 42}'
1008+
1009+
1010+
@pytest.mark.asyncio
1011+
async def test_tool_result_json_serialization_list(alist):
1012+
"""Test that list results are serialized as JSON."""
1013+
1014+
@strands.tool
1015+
def list_tool() -> list:
1016+
"""Returns a list."""
1017+
return [1, "two", {"three": 3}]
1018+
1019+
tool_use = {"toolUseId": "test-id", "input": {}}
1020+
stream = list_tool.stream(tool_use, {})
1021+
1022+
result = (await alist(stream))[-1]
1023+
text = result["tool_result"]["content"][0]["text"]
1024+
1025+
assert text == '[1, "two", {"three": 3}]'
1026+
1027+
1028+
@pytest.mark.asyncio
1029+
async def test_tool_result_json_serialization_pydantic(alist):
1030+
"""Test that Pydantic model results are serialized as JSON."""
1031+
from pydantic import BaseModel
1032+
1033+
class MyModel(BaseModel):
1034+
name: str
1035+
count: int
1036+
1037+
@strands.tool
1038+
def pydantic_tool() -> MyModel:
1039+
"""Returns a Pydantic model."""
1040+
return MyModel(name="test", count=5)
1041+
1042+
tool_use = {"toolUseId": "test-id", "input": {}}
1043+
stream = pydantic_tool.stream(tool_use, {})
1044+
1045+
result = (await alist(stream))[-1]
1046+
text = result["tool_result"]["content"][0]["text"]
1047+
1048+
assert text == '{"name":"test","count":5}'
1049+
1050+
1051+
@pytest.mark.asyncio
1052+
async def test_tool_result_json_serialization_pydantic_non_serializable(alist):
1053+
"""Test that Pydantic models with non-serializable fields fall back to str()."""
1054+
from pydantic import BaseModel
1055+
1056+
class NonSerializable:
1057+
def __repr__(self):
1058+
return "NonSerializable()"
1059+
1060+
class MyModel(BaseModel):
1061+
model_config = {"arbitrary_types_allowed": True}
1062+
data: NonSerializable
1063+
1064+
@strands.tool
1065+
def pydantic_tool() -> MyModel:
1066+
"""Returns a Pydantic model with non-serializable field."""
1067+
return MyModel(data=NonSerializable())
1068+
1069+
tool_use = {"toolUseId": "test-id", "input": {}}
1070+
stream = pydantic_tool.stream(tool_use, {})
1071+
1072+
result = (await alist(stream))[-1]
1073+
text = result["tool_result"]["content"][0]["text"]
1074+
1075+
assert text == "data=NonSerializable()"
1076+
1077+
1078+
@pytest.mark.asyncio
1079+
async def test_tool_result_json_serialization_non_serializable(alist):
1080+
"""Test that non-JSON-serializable results fall back to str()."""
1081+
1082+
class CustomClass:
1083+
def __str__(self):
1084+
return "custom_str_repr"
1085+
1086+
@strands.tool
1087+
def custom_tool() -> Any:
1088+
"""Returns a non-serializable object."""
1089+
return CustomClass()
1090+
1091+
tool_use = {"toolUseId": "test-id", "input": {}}
1092+
stream = custom_tool.stream(tool_use, {})
1093+
1094+
result = (await alist(stream))[-1]
1095+
text = result["tool_result"]["content"][0]["text"]
1096+
1097+
assert text == "custom_str_repr"
1098+
1099+
1100+
@pytest.mark.asyncio
1101+
async def test_tool_result_string_not_json_encoded(alist):
1102+
"""Test that string results are NOT JSON-encoded (no extra quotes)."""
1103+
1104+
@strands.tool
1105+
def string_tool() -> str:
1106+
"""Returns a string."""
1107+
return "hello world"
1108+
1109+
tool_use = {"toolUseId": "test-id", "input": {}}
1110+
stream = string_tool.stream(tool_use, {})
1111+
1112+
result = (await alist(stream))[-1]
1113+
text = result["tool_result"]["content"][0]["text"]
1114+
1115+
assert text == "hello world"
1116+
1117+
9951118
def test_docstring_parsing():
9961119
"""Test that function docstring is correctly parsed into tool spec."""
9971120

0 commit comments

Comments
 (0)