Skip to content

Commit e7346bd

Browse files
vblagojesjrl
andauthored
feat: Add state-based configuration support to MCPToolset (#2689)
* Add state-based configuration support to MCPToolset * Some final touches * Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * Update integrations/mcp/src/haystack_integrations/tools/mcp/mcp_toolset.py Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com> * PR touches * Add MCP tool/Agent state io integration test * Test collision fixes * PR feedback - mpangrazzi --------- Co-authored-by: Sebastian Husch Lee <10526848+sjrl@users.noreply.github.com>
1 parent 8e48a7b commit e7346bd

5 files changed

Lines changed: 632 additions & 28 deletions

File tree

integrations/mcp/src/haystack_integrations/tools/mcp/mcp_tool.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import concurrent.futures
7+
import json
78
import threading
89
import warnings
910
from abc import ABC, abstractmethod
@@ -1048,12 +1049,13 @@ def _connect_and_initialize(self, tool_name: str) -> types.Tool:
10481049

10491050
return tool
10501051

1051-
def _invoke_tool(self, **kwargs: Any) -> str:
1052+
def _invoke_tool(self, **kwargs: Any) -> str | dict[str, Any]:
10521053
"""
10531054
Synchronous tool invocation.
10541055
10551056
:param kwargs: Arguments to pass to the tool
1056-
:returns: JSON string representation of the tool invocation result
1057+
:returns: JSON string or dictionary representation of the tool invocation result.
1058+
Returns a dictionary when outputs_to_state is configured to enable state updates.
10571059
"""
10581060
logger.debug(f"TOOL: Invoking tool '{self.name}' with args: {kwargs}")
10591061
try:
@@ -1070,6 +1072,26 @@ async def invoke():
10701072
logger.debug(f"TOOL: About to run invoke for '{self.name}'")
10711073
result = AsyncExecutor.get_instance().run(invoke(), timeout=self._invocation_timeout)
10721074
logger.debug(f"TOOL: Invoke complete for '{self.name}', result type: {type(result)}")
1075+
1076+
# Parse JSON to dict only when outputs_to_state is configured.
1077+
# ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise.
1078+
if self.outputs_to_state:
1079+
parsed = json.loads(result)
1080+
1081+
# Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc.
1082+
# Parse only first TextContent block (ToolInvoker requires dict, not list).
1083+
content = parsed.get("content", [])
1084+
for block in content:
1085+
if isinstance(block, dict) and block.get("type") == "text":
1086+
text = block.get("text", "")
1087+
try:
1088+
return json.loads(text)
1089+
except (json.JSONDecodeError, TypeError):
1090+
return text
1091+
1092+
# No TextContent found, return full parsed response as fallback
1093+
return parsed
1094+
10731095
return result
10741096
except (MCPError, TimeoutError) as e:
10751097
logger.debug(f"TOOL: Known error during invoke of '{self.name}': {e!s}")
@@ -1081,19 +1103,41 @@ async def invoke():
10811103
message = f"Failed to invoke tool '{self.name}' with args: {kwargs} , got error: {e!s}"
10821104
raise MCPInvocationError(message, self.name, kwargs) from e
10831105

1084-
async def ainvoke(self, **kwargs: Any) -> str:
1106+
async def ainvoke(self, **kwargs: Any) -> str | dict[str, Any]:
10851107
"""
10861108
Asynchronous tool invocation.
10871109
10881110
:param kwargs: Arguments to pass to the tool
1089-
:returns: JSON string representation of the tool invocation result
1111+
:returns: JSON string or dictionary representation of the tool invocation result.
1112+
Returns a dictionary when outputs_to_state is configured to enable state updates.
10901113
:raises MCPInvocationError: If the tool invocation fails
10911114
:raises TimeoutError: If the operation times out
10921115
"""
10931116
try:
10941117
self.warm_up()
10951118
client = cast(MCPClient, self._client)
1096-
return await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
1119+
result = await asyncio.wait_for(client.call_tool(self.name, kwargs), timeout=self._invocation_timeout)
1120+
1121+
# Parse JSON to dict only when outputs_to_state is configured.
1122+
# ToolInvoker requires dict for _merge_tool_outputs(); ToolCallResult.result expects str otherwise.
1123+
if self.outputs_to_state:
1124+
parsed = json.loads(result)
1125+
1126+
# Per MCP spec, content[] may contain TextContent, ImageContent, AudioContent, etc.
1127+
# Parse only first TextContent block (ToolInvoker requires dict, not list).
1128+
content = parsed.get("content", [])
1129+
for block in content:
1130+
if isinstance(block, dict) and block.get("type") == "text":
1131+
text = block.get("text", "")
1132+
try:
1133+
return json.loads(text)
1134+
except (json.JSONDecodeError, TypeError):
1135+
return text
1136+
1137+
# No TextContent found, return full parsed response as fallback
1138+
return parsed
1139+
1140+
return result
10971141
except asyncio.TimeoutError as e:
10981142
message = f"Tool invocation timed out after {self._invocation_timeout} seconds"
10991143
raise TimeoutError(message) from e

0 commit comments

Comments
 (0)