forked from lightspeed-core/lightspeed-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtools.py
More file actions
154 lines (130 loc) · 5.55 KB
/
tools.py
File metadata and controls
154 lines (130 loc) · 5.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""Handler for REST API call to list available tools from MCP servers."""
from typing import Annotated, Any
from fastapi import APIRouter, Depends, HTTPException, Request
from llama_stack_client import APIConnectionError, BadRequestError, AuthenticationError
from llama_stack.core.datatypes import AuthenticationRequiredError
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.config import Action
from models.responses import (
ForbiddenResponse,
InternalServerErrorResponse,
ServiceUnavailableResponse,
ToolsResponse,
UnauthorizedResponse,
)
from utils.endpoints import check_configuration_loaded
from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
from utils.tool_formatter import format_tools_list
from log import get_logger
logger = get_logger(__name__)
router = APIRouter(tags=["tools"])

# OpenAPI response documentation for GET /tools, keyed by HTTP status code and
# passed to ``router.get(responses=...)``. Entries are inserted in ascending
# status order.
tools_responses: dict[int | str, dict[str, Any]] = {}
# Successful listing of the consolidated tools.
tools_responses[200] = ToolsResponse.openapi_response()
# Authentication failures (examples: missing header, missing token).
tools_responses[401] = UnauthorizedResponse.openapi_response(
    examples=["missing header", "missing token"]
)
# Caller is not permitted to use this endpoint.
tools_responses[403] = ForbiddenResponse.openapi_response(examples=["endpoint"])
# Server-side problems such as configuration issues.
tools_responses[500] = InternalServerErrorResponse.openapi_response(
    examples=["configuration"]
)
# Backend (Llama Stack) could not be reached.
tools_responses[503] = ServiceUnavailableResponse.openapi_response()
def _llama_stack_unavailable(error: APIConnectionError) -> HTTPException:
    """Log a Llama Stack connection failure and build the matching HTTPException.

    Parameters:
        error: The connection error raised by the Llama Stack client.

    Returns:
        HTTPException: Built from a ``ServiceUnavailableResponse`` for the
        "Llama Stack" backend, carrying ``str(error)`` as the cause. The caller
        is expected to raise it (chained from ``error``).
    """
    logger.error("Unable to connect to Llama Stack: %s", error)
    response = ServiceUnavailableResponse(backend_name="Llama Stack", cause=str(error))
    return HTTPException(**response.model_dump())


def _resolve_server_source(toolgroup_id: str, mcp_servers_by_name: dict[str, Any]) -> str:
    """Return the ``server_source`` label for every tool in a toolgroup.

    Parameters:
        toolgroup_id: Identifier of the Llama Stack toolgroup.
        mcp_servers_by_name: Configured MCP servers indexed by their ``name``.

    Returns:
        str: The configured MCP server's ``url`` when the toolgroup identifier
        matches a configured MCP server name, otherwise ``"builtin"``.
    """
    mcp_server = mcp_servers_by_name.get(toolgroup_id)
    if mcp_server is None:
        return "builtin"
    return mcp_server.url


@router.get("/tools", responses=tools_responses)
@authorize(Action.GET_TOOLS)
async def tools_endpoint_handler(  # pylint: disable=too-many-locals,too-many-statements
    request: Request,
    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
) -> ToolsResponse:
    """
    Handle requests to the /tools endpoint.

    Process GET requests to the /tools endpoint, returning a consolidated list of
    available tools from every toolgroup known to Llama Stack. Toolgroups whose
    identifier matches a configured MCP server are tagged with that server's URL;
    all others are tagged as ``"builtin"``.

    Toolgroups that Llama Stack rejects with ``BadRequestError`` (treated as
    "not found") are logged and skipped instead of failing the whole request.

    Raises:
        HTTPException: Built from ``ServiceUnavailableResponse`` if Llama Stack
            cannot be reached, or from ``UnauthorizedResponse`` if listing a
            toolgroup's tools fails authentication. In the latter case, when the
            toolgroup has an MCP endpoint it is first probed via
            ``probe_mcp_oauth_and_raise_401``.

    Returns:
        ToolsResponse: An object containing the consolidated list of available tools
        with metadata including tool name, description, parameters, and server source.
    """
    # Used only by the middleware
    _ = auth
    # Nothing interesting in the request
    _ = request

    check_configuration_loaded(configuration)

    try:
        client = AsyncLlamaStackClientHolder().get_client()
        logger.debug("Retrieving tools from all toolgroups")
        toolgroups_response = await client.toolgroups.list()
    except APIConnectionError as e:
        raise _llama_stack_unavailable(e) from e

    # Index configured MCP servers by name once, instead of scanning the list
    # for every tool. setdefault keeps the first entry on duplicate names, which
    # matches the first-match semantics of a linear search.
    mcp_servers_by_name: dict[str, Any] = {}
    for server in configuration.mcp_servers or []:
        mcp_servers_by_name.setdefault(server.name, server)

    consolidated_tools: list[dict[str, Any]] = []
    builtin_tools_count = 0
    for toolgroup in toolgroups_response:
        try:
            # Get tools for each toolgroup
            tools_response = await client.tools.list(toolgroup_id=toolgroup.identifier)
        except BadRequestError:
            logger.error("Toolgroup %s is not found", toolgroup.identifier)
            continue
        except (AuthenticationError, AuthenticationRequiredError) as e:
            if toolgroup.mcp_endpoint:
                await probe_mcp_oauth_and_raise_401(
                    toolgroup.mcp_endpoint.uri, chain_from=e
                )
            error_response = UnauthorizedResponse(cause=str(e))
            raise HTTPException(**error_response.model_dump()) from e
        except APIConnectionError as e:
            raise _llama_stack_unavailable(e) from e

        # The source depends only on the toolgroup, so resolve it once per group.
        server_source = _resolve_server_source(toolgroup.identifier, mcp_servers_by_name)

        # Convert tools to dict format, tagging each with its source
        tools_count = 0
        for tool in tools_response:
            tool_dict = dict(tool)
            tool_dict["server_source"] = server_source
            consolidated_tools.append(tool_dict)
            tools_count += 1

        if server_source == "builtin":
            builtin_tools_count += tools_count

        logger.debug(
            "Retrieved %d tools from toolgroup %s (source: %s)",
            tools_count,
            toolgroup.identifier,
            server_source,
        )

    total_tools = len(consolidated_tools)
    logger.info(
        "Retrieved total of %d tools (%d from built-in toolgroups, %d from MCP servers)",
        total_tools,
        builtin_tools_count,
        total_tools - builtin_tools_count,
    )

    # Format tools with structured description parsing
    formatted_tools = format_tools_list(consolidated_tools)
    return ToolsResponse(tools=formatted_tools)