Skip to content

Commit 1bc49be

Browse files
authored
Handle name collisions between local and remote tools (#52)
1 parent 644fdc4 commit 1bc49be

File tree

12 files changed

+274
-133
lines changed

12 files changed

+274
-133
lines changed

.basedpyright/baseline.json

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -201,30 +201,6 @@
201201
}
202202
],
203203
"./splunklib/ai/tools.py": [
204-
{
205-
"code": "reportUnusedImport",
206-
"range": {
207-
"startColumn": 7,
208-
"endColumn": 22,
209-
"lineCount": 1
210-
}
211-
},
212-
{
213-
"code": "reportPrivateUsage",
214-
"range": {
215-
"startColumn": 43,
216-
"endColumn": 75,
217-
"lineCount": 1
218-
}
219-
},
220-
{
221-
"code": "reportUnannotatedClassAttribute",
222-
"range": {
223-
"startColumn": 13,
224-
"endColumn": 27,
225-
"lineCount": 1
226-
}
227-
},
228204
{
229205
"code": "reportUnknownVariableType",
230206
"range": {
@@ -30169,22 +30145,6 @@
3016930145
}
3017030146
],
3017130147
"./tests/integration/ai/test_agent_mcp_tools.py": [
30172-
{
30173-
"code": "reportUnusedImport",
30174-
"range": {
30175-
"startColumn": 21,
30176-
"endColumn": 30,
30177-
"lineCount": 1
30178-
}
30179-
},
30180-
{
30181-
"code": "reportPrivateUsage",
30182-
"range": {
30183-
"startColumn": 4,
30184-
"endColumn": 24,
30185-
"lineCount": 1
30186-
}
30187-
},
3018830148
{
3018930149
"code": "reportUnknownArgumentType",
3019030150
"range": {

examples/ai_custom_search_app/metadata/local.meta

Whitespace-only changes.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,15 @@ reportUnusedCallResult = false
8686
[tool.ruff.lint]
8787
fixable = ["ALL"]
8888
select = [
89-
"ANN", # flake8 type annotations
89+
"ANN", # flake-8-annotations
9090
"C4", # comprehensions
9191
"DOC", # pydocstyle
9292
"E", # pycodestyle
9393
"F", # pyflakes
9494
"I", # isort
95-
"UP", # pyupgrade
95+
"PT", # flake-8-pytest-rules
9696
"RUF", # ruff-specific rules
97+
"UP", # pyupgrade
9798
]
9899

99100
[tool.ruff.lint.isort]

splunklib/ai/agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from splunklib.ai.tool_filtering import ToolFilters, filter_tools
3030
from splunklib.ai.tools import (
3131
Tool,
32+
ToolType,
3233
build_local_tools_path,
3334
connect_local_mcp,
3435
connect_remote_mcp,
@@ -171,7 +172,11 @@ async def _start_agent(self) -> AsyncGenerator[Self]:
171172
)
172173
self.logger.debug("Loading local tools")
173174
local_tools = await load_mcp_tools(
174-
local_session, "local", app_id, self.trace_id, self._service
175+
local_session,
176+
ToolType.LOCAL,
177+
app_id,
178+
self.trace_id,
179+
self._service,
175180
)
176181
self.logger.debug(f"Local tools loaded; {local_tools=}")
177182
tools.extend(local_tools)
@@ -188,7 +193,7 @@ async def _start_agent(self) -> AsyncGenerator[Self]:
188193
self.logger.debug("Loading remote tools - MCP Server available")
189194
remote_tools = await load_mcp_tools(
190195
remote_session,
191-
"remote",
196+
ToolType.REMOTE,
192197
app_id,
193198
self.trace_id,
194199
self._service,

splunklib/ai/core/backend_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
def get_backend() -> Backend:
1919
"""Get a backend instance."""
20-
20+
# Lazy import to avoid circular dependency hell between LangChain and SDK
2121
from splunklib.ai.engines.langchain import langchain_backend_factory
2222

2323
# NOTE: For now we're just using the langchain backend implementation

splunklib/ai/engines/langchain.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
tool_middleware,
8888
)
8989
from splunklib.ai.model import OpenAIModel, PredefinedModel
90-
from splunklib.ai.tools import Tool, ToolException
90+
from splunklib.ai.tools import Tool, ToolException, ToolType
9191

9292
# Represents a prefix reserved only for internal use.
9393
# No user-visible tool or subagent name can be prefixed with it.
@@ -102,6 +102,10 @@
102102
# backward compatibility measure - we're free to use any prefixed tool name.
103103
CONFLICTING_TOOL_PREFIX = f"{RESERVED_LC_TOOL_PREFIX}tool-"
104104

105+
# Prepended to a local tool name when passed to LangChain to both avoid name conflicts
106+
# and to allow recovering tool type during LC -> SDK conversion
107+
LOCAL_TOOL_PREFIX = f"{RESERVED_LC_TOOL_PREFIX}local-"
108+
105109
AGENT_AS_TOOLS_PROMPT = f"""
106110
You are provided with Agents.
107111
Agents are more advanced TOOLS, which start with "{AGENT_PREFIX}" prefix.
@@ -242,16 +246,25 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
242246
)
243247

244248

249+
def _prepare_langchain_tools(agent_tools: Sequence[Tool]) -> list[BaseTool]:
250+
"""We prefix every local tool name."""
251+
tools = list[BaseTool]()
252+
for a_tool in agent_tools:
253+
tools.append(_create_langchain_tool(a_tool))
254+
255+
return tools
256+
257+
245258
@final
246259
class LangChainBackend(Backend):
247260
@override
248261
async def create_agent(
249262
self,
250263
agent: BaseAgent[OutputT],
251264
) -> AgentImpl[OutputT]:
252-
system_prompt = agent.system_prompt
253-
tools = [_create_langchain_tool(t) for t in agent.tools]
265+
tools = _prepare_langchain_tools(agent.tools)
254266

267+
system_prompt = agent.system_prompt
255268
if agent.agents:
256269
seen_names: set[str] = set()
257270
for subagent in agent.agents:
@@ -466,7 +479,8 @@ def _convert_tool_request_to_lc(
466479

467480

468481
def _convert_subagent_request_to_lc(
469-
request: SubagentRequest, original_request: LC_ToolCallRequest
482+
request: SubagentRequest,
483+
original_request: LC_ToolCallRequest,
470484
) -> LC_ToolCallRequest:
471485
return original_request.override(
472486
tool_call=_map_tool_call_to_langchain(request.call),
@@ -475,7 +489,8 @@ def _convert_subagent_request_to_lc(
475489

476490

477491
def _convert_model_request_to_lc(
478-
request: ModelRequest, original_request: LC_ModelRequest
492+
request: ModelRequest,
493+
original_request: LC_ModelRequest,
479494
) -> LC_ModelRequest:
480495
return original_request.override(
481496
system_message=LC_SystemMessage(content=request.system_message),
@@ -504,7 +519,7 @@ def _convert_tool_message_to_lc(
504519
case SubagentMessage():
505520
name = _normalize_agent_name(message.name)
506521
case ToolMessage():
507-
name = _normalize_tool_name(message.name)
522+
name = _normalize_tool_name(message.name, message.type)
508523

509524
return LC_ToolMessage(
510525
name=name,
@@ -515,11 +530,10 @@ def _convert_tool_message_to_lc(
515530

516531

517532
def _convert_tool_response_to_lc(
518-
response: ToolResponse,
519-
call: ToolCall,
533+
response: ToolResponse, call: ToolCall
520534
) -> LC_ToolMessage:
521535
return LC_ToolMessage(
522-
name=_normalize_tool_name(call.name),
536+
name=_normalize_tool_name(call.name, call.type),
523537
content=response.content,
524538
tool_call_id=call.id,
525539
status=response.status,
@@ -554,11 +568,18 @@ def _convert_tool_message_from_lc(
554568
assert message.name is not None, (
555569
"LangChain responded with a nameless tool call"
556570
)
571+
572+
tool_type: ToolType = (
573+
ToolType.LOCAL
574+
if message.name.startswith(LOCAL_TOOL_PREFIX)
575+
else ToolType.REMOTE
576+
)
557577
return ToolMessage(
558578
name=_denormalize_tool_name(message.name),
559579
content=message.content.__str__(),
560580
call_id=message.tool_call_id,
561581
status=message.status,
582+
type=tool_type,
562583
)
563584
case LC_Command():
564585
# NOTE: for now the command is not implemented
@@ -668,7 +689,7 @@ async def _tool_call(**kwargs: dict[str, Any]) -> dict[str, Any] | list[str]:
668689
except ToolException as e:
669690
raise LC_ToolException(*e.args) from e
670691
except LC_ToolException:
671-
assert False, (
692+
assert False, ( # noqa: PT015
672693
"ToolException from LangChain should not be raised in tool.func"
673694
)
674695

@@ -687,7 +708,7 @@ async def _tool_call(**kwargs: dict[str, Any]) -> dict[str, Any] | list[str]:
687708
return result.content
688709

689710
return StructuredTool(
690-
name=_normalize_tool_name(tool.name),
711+
name=_normalize_tool_name(tool.name, tool.type),
691712
description=tool.description,
692713
args_schema=tool.input_schema,
693714
coroutine=_tool_call,
@@ -709,14 +730,24 @@ def _denormalize_agent_name(name: str) -> str:
709730
return name.removeprefix(AGENT_PREFIX)
710731

711732

712-
def _normalize_tool_name(name: str) -> str:
733+
def _normalize_tool_name(name: str, tool_type: ToolType) -> str:
734+
if tool_type == ToolType.LOCAL:
735+
return LOCAL_TOOL_PREFIX + name
736+
713737
if name.startswith(RESERVED_LC_TOOL_PREFIX):
714-
return f"{CONFLICTING_TOOL_PREFIX}{name}"
738+
# Tool name contains our reserved prefix, see comment
739+
# on CONFLICTING_TOOL_PREFIX for more details
740+
return CONFLICTING_TOOL_PREFIX + name
741+
715742
return name
716743

717744

718745
def _denormalize_tool_name(name: str) -> str:
719-
return name.removeprefix(CONFLICTING_TOOL_PREFIX)
746+
if name.startswith(RESERVED_LC_TOOL_PREFIX):
747+
assert "-" in name, "Invalid prefix in tool name"
748+
_prefix, name = name.split("-", maxsplit=1)
749+
750+
return name
720751

721752

722753
def _agent_as_tool(agent: BaseAgent[OutputT]) -> StructuredTool:
@@ -757,17 +788,22 @@ async def _run(**kwargs: dict[str, Any]) -> OutputT | str:
757788

758789

759790
def _map_tool_call_from_langchain(tool_call: LC_ToolCall) -> ToolCall | SubagentCall:
760-
if tool_call["name"].startswith(AGENT_PREFIX):
791+
name = tool_call["name"]
792+
if name.startswith(AGENT_PREFIX):
761793
return SubagentCall(
762-
name=_denormalize_agent_name(tool_call["name"]),
794+
name=_denormalize_agent_name(name),
763795
args=tool_call["args"],
764796
id=tool_call["id"],
765797
)
766798

799+
tool_type: ToolType = (
800+
ToolType.LOCAL if name.startswith(LOCAL_TOOL_PREFIX) else ToolType.REMOTE
801+
)
767802
return ToolCall(
768-
name=_denormalize_tool_name(tool_call["name"]),
803+
name=_denormalize_tool_name(name),
769804
args=tool_call["args"],
770805
id=tool_call["id"],
806+
type=tool_type,
771807
)
772808

773809

@@ -776,13 +812,9 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
776812
case SubagentCall():
777813
name = _normalize_agent_name(call.name)
778814
case ToolCall():
779-
name = _normalize_tool_name(call.name)
815+
name = _normalize_tool_name(call.name, call.type)
780816

781-
return LC_ToolCall(
782-
name=name,
783-
args=call.args,
784-
id=call.id,
785-
)
817+
return LC_ToolCall(id=call.id, name=name, args=call.args)
786818

787819

788820
def _map_message_from_langchain(message: LC_BaseMessage) -> BaseMessage:
@@ -806,7 +838,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
806838
match message:
807839
case AIMessage():
808840
lc_message = LC_AIMessage(content=message.content)
809-
# this field can't be set via constructor
841+
# This field can't be set via constructor
810842
lc_message.tool_calls = [
811843
_map_tool_call_to_langchain(c) for c in message.calls
812844
]

splunklib/ai/messages.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414

15-
1615
from collections.abc import Sequence
1716
from dataclasses import dataclass, field
1817
from typing import Any, Generic, Literal, TypeVar
1918

2019
from pydantic import BaseModel
2120

22-
OutputT = TypeVar("OutputT", default=None, covariant=True, bound=BaseModel | None)
21+
from splunklib.ai.tools import ToolType
2322

2423

2524
@dataclass(frozen=True)
2625
class ToolCall:
2726
name: str
2827
args: dict[str, Any]
2928
id: str | None # TODO: can be None?
29+
type: ToolType
3030

3131

3232
@dataclass(frozen=True)
@@ -41,7 +41,7 @@ class BaseMessage:
4141
role: str = ""
4242
content: str = field(default="")
4343

44-
def __post_init__(self):
44+
def __post_init__(self) -> None:
4545
if type(self) is BaseMessage:
4646
raise TypeError(
4747
"BaseMessage is an abstract class and cannot be instantiated"
@@ -79,14 +79,15 @@ class AIMessage(BaseMessage):
7979

8080
@dataclass(frozen=True)
8181
class ToolMessage(BaseMessage):
82-
"""
83-
ToolMessage represents a response of a tool call
84-
"""
82+
"""ToolMessage represents a response of a tool call"""
83+
84+
# TODO: See if we can remove the defaults - they should always be populated manually
8585

8686
role: Literal["tool"] = "tool"
8787
name: str = field(default="")
8888
call_id: str = field(default="")
8989
status: Literal["success", "error"] = "success"
90+
type: ToolType = ToolType.LOCAL
9091

9192

9293
@dataclass(frozen=True)
@@ -110,6 +111,9 @@ class SubagentMessage(BaseMessage):
110111
status: Literal["success", "error"] = "success"
111112

112113

114+
OutputT = TypeVar("OutputT", default=None, covariant=True, bound=BaseModel | None)
115+
116+
113117
@dataclass(frozen=True)
114118
class AgentResponse(Generic[OutputT]):
115119
# in case output_schema is provided, this will hold the parsed structured output

0 commit comments

Comments
 (0)