Skip to content

Commit 50c2d2c

Browse files
mateusz834szykol
authored andcommitted
Add ToolRegistry (#4)
This change implements a ToolRegistry, containing a tool decorator, that is used to register MCP tools. The decorator infers a JSON schema from the input and output params of the decorated function. Additionally to support Splunk-specific needs, this change introduces a ToolContext, which tools can accept to gain access to additional functionalities. For now the ToolContext exposes a connection to the Splunk REST API, but in future we will add more functionalities there like: logging, tool cancellation, tool notifications and so on.
1 parent 6d641a3 commit 50c2d2c

File tree

9 files changed

+1013
-0
lines changed

9 files changed

+1013
-0
lines changed

splunklib/ai/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#
2+
# Copyright © 2011-2025 Splunk, Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"): you may
5+
# not use this file except in compliance with the License. You may obtain
6+
# a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
# License for the specific language governing permissions and limitations
14+
# under the License.

splunklib/ai/registry.py

Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
#
2+
# Copyright © 2011-2025 Splunk, Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"): you may
5+
# not use this file except in compliance with the License. You may obtain
6+
# a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
# License for the specific language governing permissions and limitations
14+
# under the License.
15+
import asyncio
16+
import inspect
17+
from dataclasses import asdict, dataclass
18+
from typing import Any, Callable, Generic, ParamSpec, TypeVar, get_type_hints
19+
20+
import mcp.types as types
21+
from mcp.server.lowlevel import Server
22+
from pydantic import TypeAdapter
23+
24+
from splunklib.binding import _spliturl
25+
from splunklib.client import Service, connect
26+
27+
28+
class ToolContext:
29+
"""
30+
ToolContext provides a way to interact with the tool execution context.
31+
A new instance is automatically injected as a function parameter when a
32+
relevant type hint is detected.
33+
"""
34+
35+
_management_url: str | None = None
36+
_management_token: str | None = None
37+
_service: Service | None = None
38+
39+
@property
40+
def service(self) -> Service:
41+
"""
42+
returns a connected :class:`Service` object to the Splunk instance,
43+
that executed the tool.
44+
"""
45+
if self._service is not None:
46+
return self._service
47+
48+
assert all((self._management_url, self._management_token)), (
49+
"Invalid tool invocation, missing management_url and/or management_token"
50+
)
51+
52+
scheme, host, port, path = _spliturl(self._management_url)
53+
s = connect(
54+
scheme=scheme,
55+
host=host,
56+
port=port,
57+
path=path,
58+
token=self._management_token,
59+
autologin=True,
60+
)
61+
self._service = s
62+
return s
63+
64+
65+
_T = TypeVar("_T", default=Any)
66+
67+
68+
@dataclass
69+
class _WrappedResult(Generic[_T]):
70+
result: _T
71+
72+
73+
_P = ParamSpec("_P")
74+
_R = TypeVar("_R")
75+
76+
77+
class ToolRegistryRuntimeError(RuntimeError):
78+
"""Raised when a tool registry operation fails."""
79+
80+
pass
81+
82+
83+
class ToolRegistry:
84+
_server: Server
85+
_tools: list[types.Tool]
86+
_tools_func: dict[str, Callable]
87+
_tools_wrapped_result: dict[str, bool]
88+
_executing: bool = False
89+
90+
def __init__(self) -> None:
91+
self._server = Server("Tool Registry")
92+
self._tools = []
93+
self._tools_func = {}
94+
self._tools_wrapped_result = {}
95+
self._register_handlers()
96+
97+
def _register_handlers(self) -> None:
98+
@self._server.list_tools()
99+
async def _() -> list[types.Tool]:
100+
return self._list_tools()
101+
102+
@self._server.call_tool(validate_input=True)
103+
async def _(name: str, arguments: dict[str, Any]) -> types.CallToolResult:
104+
return self._call_tool(name, arguments)
105+
106+
def _list_tools(self) -> list[types.Tool]:
107+
return self._tools
108+
109+
def _call_tool(self, name: str, arguments: dict[str, Any]) -> types.CallToolResult:
110+
func = self._tools_func.get(name)
111+
if func is None:
112+
raise ValueError(f"Tool {name} does not exist")
113+
114+
ctx = ToolContext()
115+
meta = self._server.request_context.meta
116+
if meta is not None:
117+
splunk_meta = meta.model_dump().get("splunk")
118+
if splunk_meta is not None:
119+
ctx._management_url = splunk_meta.get("management_url")
120+
ctx._management_token = splunk_meta.get("management_token")
121+
122+
for k in func.__annotations__:
123+
if func.__annotations__[k] == ToolContext:
124+
assert arguments.get(k) is None, (
125+
"Improper input schema was generated or schema verification is malfunctioning"
126+
)
127+
arguments[k] = ctx
128+
129+
res = func(**arguments)
130+
131+
if self._tools_wrapped_result.get(name):
132+
res = _WrappedResult(res)
133+
134+
return types.CallToolResult(
135+
structuredContent=asdict(res),
136+
content=[],
137+
)
138+
139+
def _input_schema(self, func: Callable[_P, _R]) -> dict[str, Any]:
140+
"""
141+
Generates a input schema for the provided func, skips arguments of type: `ToolContext`.
142+
"""
143+
144+
ctxs: list[str] = []
145+
for k in func.__annotations__:
146+
if func.__annotations__[k] == ToolContext:
147+
ctxs.append(k)
148+
149+
input_schema = TypeAdapter(_drop_type_annotations_of(func, ctxs)).json_schema()
150+
151+
# _drop_type_annotations_of removed the type annotation to prevent json_schema()
152+
# from attempting to infer type information for ToolContext (which would fail).
153+
# However, ToolContext fields still appear in the properties and required
154+
# fields of the schema (we only made sure that no type information was generated
155+
# in the schema, that corresponds to the ToolContext), so we need to remove those
156+
# references here as well.
157+
for ctx in ctxs:
158+
props = input_schema.get("properties", {})
159+
props.pop(ctx)
160+
161+
if ctx in input_schema.get("required", []):
162+
input_schema["required"].remove(ctx)
163+
if not input_schema["required"]:
164+
input_schema.pop("required")
165+
166+
return input_schema
167+
168+
def _output_schema(self, func: Callable[_P, _R]) -> tuple[dict[str, Any], bool]:
169+
"""
170+
Generates a output schema for the provided func, if necessary wraps the
171+
output type with :class:`_WrappedResult`.
172+
173+
Returns an output schema and a boolean that signals whether the result
174+
needs to be wrapped.
175+
"""
176+
177+
sig = inspect.signature(func)
178+
output_schema = TypeAdapter(sig.return_annotation).json_schema(
179+
mode="serialization"
180+
)
181+
182+
# Since all structured results must be an object in MCP,
183+
# if the result type of the provided function is not an object,
184+
# then wrap it in a _WrappedResult to make it a object.
185+
is_object = (
186+
output_schema.get("type") == "object" or "properties" in output_schema
187+
)
188+
if not is_object:
189+
output_schema = TypeAdapter(
190+
_WrappedResult[
191+
get_type_hints(func, include_extras=True).get(
192+
"return", sig.return_annotation
193+
)
194+
]
195+
).json_schema(mode="serialization")
196+
return output_schema, True
197+
return output_schema, False
198+
199+
def tool(
200+
self,
201+
name: str | None = None,
202+
description: str | None = None,
203+
title: str | None = None,
204+
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
205+
"""
206+
Decorator that registers a function with the ToolRegistry.
207+
208+
The decorator automatically infers a JSON Schema from the function's
209+
type hints, using them to define the tool's expected input and output
210+
structure.
211+
212+
Functions may optionally accept a :class:`ToolContext` parameter, which provides
213+
access to additional tool-related functionality.
214+
215+
:param name: An optional name of the tool.
216+
If omitted, the function's name is used.
217+
:param description: An optional human-readable description of the tool.
218+
If omitted, the function's docstring is used.
219+
220+
"""
221+
222+
def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]:
223+
nonlocal description
224+
if description is None:
225+
description = func.__doc__
226+
227+
nonlocal name
228+
if name is None:
229+
name = func.__name__
230+
231+
if self._executing:
232+
raise ToolRegistryRuntimeError(
233+
"ToolRegistry is already running, cannot define new tools"
234+
)
235+
236+
if self._tools_func.get(name) is not None:
237+
raise ToolRegistryRuntimeError(f"Tool {name} already defined")
238+
239+
input_schema = self._input_schema(func)
240+
output_schema, wrapped_output = self._output_schema(func)
241+
242+
self._tools.append(
243+
types.Tool(
244+
name=name,
245+
title=title,
246+
description=description,
247+
inputSchema=input_schema,
248+
outputSchema=output_schema,
249+
)
250+
)
251+
self._tools_func[name] = func
252+
self._tools_wrapped_result[name] = wrapped_output
253+
254+
return func
255+
256+
return wrapper
257+
258+
def run(self) -> None:
259+
async def run() -> None:
260+
import mcp.server.stdio
261+
from mcp.server.lowlevel import NotificationOptions
262+
from mcp.server.models import InitializationOptions
263+
264+
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
265+
await self._server.run(
266+
read_stream,
267+
write_stream,
268+
InitializationOptions(
269+
server_name="Utility App - Tool Registry",
270+
server_version="",
271+
capabilities=self._server.get_capabilities(
272+
notification_options=NotificationOptions(),
273+
experimental_capabilities={},
274+
),
275+
),
276+
)
277+
278+
self._executing = True
279+
asyncio.run(run())
280+
281+
282+
def _drop_type_annotations_of(
283+
fn: Callable[..., Any], exclude_params: list[str]
284+
) -> Callable[..., Any]:
285+
"""
286+
Creates a new function, that has the type information elided for each
287+
param in `exclude_params`.
288+
"""
289+
import types
290+
291+
original_annotations = getattr(fn, "__annotations__", {})
292+
new_annotations = {
293+
k: v for k, v in original_annotations.items() if k not in exclude_params
294+
}
295+
296+
new_func = types.FunctionType(
297+
fn.__code__,
298+
fn.__globals__,
299+
fn.__name__,
300+
fn.__defaults__,
301+
fn.__closure__,
302+
)
303+
new_func.__dict__.update(fn.__dict__)
304+
new_func.__module__ = fn.__module__
305+
new_func.__qualname__ = getattr(fn, "__qualname__", fn.__name__) # ty: ignore[unresolved-attribute]
306+
new_func.__annotations__ = new_annotations
307+
308+
return new_func

0 commit comments

Comments
 (0)