Skip to content

Commit 3784889

Browse files
sjrldavidsbatista
andauthored
fix: Fix Tool and ComponentTool serialization when specifying outputs_to_string (#9524)
* Fix serialization of outputs_to_string in Tool and ComponentTool * Add reno * Fix mypy, simplify logic * fix pylint * Fix test --------- Co-authored-by: David S. Batista <dsbatista@gmail.com>
1 parent a16ee96 commit 3784889

5 files changed

Lines changed: 46 additions & 16 deletions

File tree

haystack/tools/component_tool.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,14 @@ def to_dict(self) -> Dict[str, Any]:
206206
"""
207207
serialized_component = component_to_dict(obj=self._component, name=self.name)
208208

209-
serialized = {
209+
serialized: Dict[str, Any] = {
210210
"component": serialized_component,
211211
"name": self.name,
212212
"description": self.description,
213213
"parameters": self._unresolved_parameters,
214-
"outputs_to_string": self.outputs_to_string,
215214
"inputs_from_state": self.inputs_from_state,
216-
"outputs_to_state": self.outputs_to_state,
215+
# This is soft-copied as to not modify the attributes in place
216+
"outputs_to_state": self.outputs_to_state.copy() if self.outputs_to_state else None,
217217
}
218218

219219
if self.outputs_to_state is not None:
@@ -226,7 +226,11 @@ def to_dict(self) -> Dict[str, Any]:
226226
serialized["outputs_to_state"] = serialized_outputs
227227

228228
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
229-
serialized["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
229+
# This is soft-copied as to not modify the attributes in place
230+
serialized["outputs_to_string"] = self.outputs_to_string.copy()
231+
serialized["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
232+
else:
233+
serialized["outputs_to_string"] = None
230234

231235
return {"type": generate_qualified_class_name(type(self)), "data": serialized}
232236

haystack/tools/tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def to_dict(self) -> Dict[str, Any]:
122122
data["outputs_to_state"] = serialized_outputs
123123

124124
if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
125-
data["outputs_to_string"] = serialize_callable(self.outputs_to_string["handler"])
125+
data["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])
126126

127127
return {"type": generate_qualified_class_name(type(self)), "data": data}
128128

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Fix the serialization of ComponentTool and Tool when specifying outputs_to_string. Previously an error occurred on deserialization right after serializing if outputs_to_string is not None.

test/tools/test_component_tool.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def run(self, text: str) -> Dict[str, str]:
6060
return {"reply": f"Hello, {text}!"}
6161

6262

63+
def reply_formatter(input_text: str) -> str:
64+
return f"Formatted reply: {input_text}"
65+
66+
6367
@dataclass
6468
class User:
6569
"""A simple user dataclass."""
@@ -593,24 +597,33 @@ def test_component_tool_serde(self):
593597
component=SimpleComponent(),
594598
name="simple_tool",
595599
description="A simple tool",
600+
outputs_to_string={"source": "reply", "handler": reply_formatter},
596601
inputs_from_state={"test": "input"},
597602
outputs_to_state={"output": {"source": "out", "handler": output_handler}},
598603
)
599604

600605
# Test serialization
606+
expected_tool_dict = {
607+
"type": "haystack.tools.component_tool.ComponentTool",
608+
"data": {
609+
"component": {"type": "test_component_tool.SimpleComponent", "init_parameters": {}},
610+
"name": "simple_tool",
611+
"description": "A simple tool",
612+
"parameters": None,
613+
"outputs_to_string": {"source": "reply", "handler": "test_component_tool.reply_formatter"},
614+
"inputs_from_state": {"test": "input"},
615+
"outputs_to_state": {"output": {"source": "out", "handler": "test_component_tool.output_handler"}},
616+
},
617+
}
601618
tool_dict = tool.to_dict()
602-
assert tool_dict["type"] == "haystack.tools.component_tool.ComponentTool"
603-
assert tool_dict["data"]["name"] == "simple_tool"
604-
assert tool_dict["data"]["description"] == "A simple tool"
605-
assert "component" in tool_dict["data"]
606-
assert tool_dict["data"]["inputs_from_state"] == {"test": "input"}
607-
assert tool_dict["data"]["outputs_to_state"]["output"]["handler"] == "test_component_tool.output_handler"
619+
assert tool_dict == expected_tool_dict
608620

609621
# Test deserialization
610-
new_tool = ComponentTool.from_dict(tool_dict)
622+
new_tool = ComponentTool.from_dict(expected_tool_dict)
611623
assert new_tool.name == tool.name
612624
assert new_tool.description == tool.description
613625
assert new_tool.parameters == tool.parameters
626+
assert new_tool.outputs_to_string == tool.outputs_to_string
614627
assert new_tool.inputs_from_state == tool.inputs_from_state
615628
assert new_tool.outputs_to_state == tool.outputs_to_state
616629
assert isinstance(new_tool._component, SimpleComponent)

test/tools/test_tool.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def get_weather_report(city: str) -> str:
1313
return f"Weather report for {city}: 20°C, sunny"
1414

1515

16+
def format_string(text: str) -> str:
17+
return f"Formatted: {text}"
18+
19+
1620
parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
1721

1822

@@ -84,6 +88,8 @@ def test_to_dict(self):
8488
description="Get weather report",
8589
parameters=parameters,
8690
function=get_weather_report,
91+
outputs_to_string={"handler": format_string},
92+
inputs_from_state={"state_key": "tool_input_key"},
8793
outputs_to_state={"documents": {"handler": get_weather_report, "source": "docs"}},
8894
)
8995

@@ -94,8 +100,8 @@ def test_to_dict(self):
94100
"description": "Get weather report",
95101
"parameters": parameters,
96102
"function": "test_tool.get_weather_report",
97-
"outputs_to_string": None,
98-
"inputs_from_state": None,
103+
"outputs_to_string": {"handler": "test_tool.format_string"},
104+
"inputs_from_state": {"state_key": "tool_input_key"},
99105
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
100106
},
101107
}
@@ -108,6 +114,8 @@ def test_from_dict(self):
108114
"description": "Get weather report",
109115
"parameters": parameters,
110116
"function": "test_tool.get_weather_report",
117+
"outputs_to_string": {"handler": "test_tool.format_string"},
118+
"inputs_from_state": {"state_key": "tool_input_key"},
111119
"outputs_to_state": {"documents": {"source": "docs", "handler": "test_tool.get_weather_report"}},
112120
},
113121
}
@@ -118,8 +126,9 @@ def test_from_dict(self):
118126
assert tool.description == "Get weather report"
119127
assert tool.parameters == parameters
120128
assert tool.function == get_weather_report
121-
assert tool.outputs_to_state["documents"]["source"] == "docs"
122-
assert tool.outputs_to_state["documents"]["handler"] == get_weather_report
129+
assert tool.outputs_to_string == {"handler": format_string}
130+
assert tool.inputs_from_state == {"state_key": "tool_input_key"}
131+
assert tool.outputs_to_state == {"documents": {"source": "docs", "handler": get_weather_report}}
123132

124133

125134
def test_check_duplicate_tool_names():

0 commit comments

Comments
 (0)