Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from typing import Any
from urllib.parse import urlparse

Expand All @@ -13,6 +14,7 @@

from .mcp_tool import (
AsyncExecutor,
MCPClient,
MCPConnectionError,
MCPServerInfo,
MCPToolNotFoundError,
Expand Down Expand Up @@ -145,13 +147,19 @@ def __init__(
)

# This is a factory that creates the invocation function for the Tool
def create_invoke_tool(mcp_client, tool_name, tool_timeout):
def create_invoke_tool(
owner_toolset: "MCPToolset",
mcp_client: MCPClient,
tool_name: str,
tool_timeout: float,
) -> Callable[..., Any]:
"""Return a closure that keeps a strong reference to *owner_toolset* alive."""

def invoke_tool(**kwargs) -> Any:
"""Invoke a tool using the existing client and AsyncExecutor."""
result = AsyncExecutor.get_instance().run(
_ = owner_toolset # strong reference so GC can't collect the toolset too early
return AsyncExecutor.get_instance().run(
mcp_client.call_tool(tool_name, kwargs), timeout=tool_timeout
)
return result

return invoke_tool

Expand All @@ -170,7 +178,7 @@ def invoke_tool(**kwargs) -> Any:
name=tool_info.name,
description=tool_info.description,
parameters=tool_info.inputSchema,
function=create_invoke_tool(client, tool_info.name, self.invocation_timeout),
function=create_invoke_tool(self, client, tool_info.name, self.invocation_timeout),
)
haystack_tools.append(tool)

Expand Down