Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions haystack/tools/component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,14 @@ def to_dict(self) -> Dict[str, Any]:
"""
serialized_component = component_to_dict(obj=self._component, name=self.name)

serialized = {
serialized: Dict[str, Any] = {
"component": serialized_component,
"name": self.name,
"description": self.description,
"parameters": self._unresolved_parameters,
"outputs_to_string": self.outputs_to_string,
"inputs_from_state": self.inputs_from_state,
"outputs_to_state": self.outputs_to_state,
# This is soft-copied as to not modify the attributes in place
"outputs_to_state": self.outputs_to_state.copy() if self.outputs_to_state else None,
}

if self.outputs_to_state is not None:
Expand All @@ -226,7 +226,11 @@ def to_dict(self) -> Dict[str, Any]:
serialized["outputs_to_state"] = serialized_outputs

if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
serialized["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
# This is soft-copied as to not modify the attributes in place
serialized["outputs_to_string"] = self.outputs_to_string.copy()
serialized["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
else:
serialized["outputs_to_string"] = None

return {"type": generate_qualified_class_name(type(self)), "data": serialized}

Expand Down
2 changes: 1 addition & 1 deletion haystack/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]:
data["outputs_to_state"] = serialized_outputs

if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
data["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
data["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])

return {"type": generate_qualified_class_name(type(self)), "data": data}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fix the serialization of ComponentTool and Tool when specifying outputs_to_string. Previously an error occurred on deserialization right after serializing if outputs_to_string is not None.
27 changes: 20 additions & 7 deletions test/tools/test_component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def run(self, text: str) -> Dict[str, str]:
return {"reply": f"Hello, {text}!"}


def reply_formatter(input_text: str) -> str:
return f"Formatted reply: {input_text}"


@dataclass
class User:
"""A simple user dataclass."""
Expand Down Expand Up @@ -593,24 +597,33 @@ def test_component_tool_serde(self):
component=SimpleComponent(),
name="simple_tool",
description="A simple tool",
outputs_to_string={"source": "reply", "handler": reply_formatter},
inputs_from_state={"test": "input"},
outputs_to_state={"output": {"source": "out", "handler": output_handler}},
)

# Test serialization
expected_tool_dict = {
"type": "haystack.tools.component_tool.ComponentTool",
"data": {
"component": {"type": "test_component_tool.SimpleComponent", "init_parameters": {}},
"name": "simple_tool",
"description": "A simple tool",
"parameters": None,
"outputs_to_string": {"source": "reply", "handler": "test_component_tool.reply_formatter"},
"inputs_from_state": {"test": "input"},
"outputs_to_state": {"output": {"source": "out", "handler": "test_component_tool.output_handler"}},
},
}
tool_dict = tool.to_dict()
assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool"
assert tool_dict["data"]["name"] == "simple_tool"
assert tool_dict["data"]["description"] == "A simple tool"
assert "component" in tool_dict["data"]
assert tool_dict["data"]["inputs_from_state"] == {"test": "input"}
assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler"
assert tool_dict == expected_tool_dict

# Test deserialization
new_tool = ComponentTool.from_dict(tool_dict)
new_tool = ComponentTool.from_dict(expected_tool_dict)
assert new_tool.name == tool.name
assert new_tool.description == tool.description
assert new_tool.parameters == tool.parameters
assert new_tool.outputs_to_string == tool.outputs_to_string
assert new_tool.inputs_from_state == tool.inputs_from_state
assert new_tool.outputs_to_state == tool.outputs_to_state
assert isinstance(new_tool._component, SimpleComponent)
Expand Down
17 changes: 13 additions & 4 deletions test/tools/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def get_weather_report(city: str) -> str:
return f"Weather report for {city}: 20°C, sunny"


def format_string(text: str) -> str:
return f"Formatted: {text}"


parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}


Expand Down Expand Up @@ -84,6 +88,8 @@ def test_to_dict(self):
description="Get weather report",
parameters=parameters,
function=get_weather_report,
outputs_to_string={"handler": format_string},
inputs_from_state={"state_key": "tool_input_key"},
outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
)

Expand All @@ -94,8 +100,8 @@ def test_to_dict(self):
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"outputs_to_string": None,
"inputs_from_state": None,
"outputs_to_string": {"handler": "test_tool.format_string"},
"inputs_from_state": {"state_key": "tool_input_key"},
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
},
}
Expand All @@ -108,6 +114,8 @@ def test_from_dict(self):
"description": "Get weather report",
"parameters": parameters,
"function": "test_tool.get_weather_report",
"outputs_to_string": {"handler": "test_tool.format_string"},
"inputs_from_state": {"state_key": "tool_input_key"},
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
},
}
Expand All @@ -118,8 +126,9 @@ def test_from_dict(self):
assert tool.description == "Get weather report"
assert tool.parameters == parameters
assert tool.function == get_weather_report
assert tool.outputs_to_state["documents"]["source"] == "docs"
assert tool.outputs_to_state["documents"]["handler"] == get_weather_report
assert tool.outputs_to_string == {"handler": format_string}
assert tool.inputs_from_state == {"state_key": "tool_input_key"}
assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}}


def test_check_duplicate_tool_names():
Expand Down
Loading