8787 tool_middleware ,
8888)
8989from splunklib .ai .model import OpenAIModel , PredefinedModel
90- from splunklib .ai .tools import Tool , ToolException
90+ from splunklib .ai .tools import Tool , ToolException , ToolType
9191
9292# Represents a prefix reserved only for internal use.
9393# No user-visible tool or subagent name can be prefixed with it.
102102# backward compatibility measure - we're free to use any prefixed tool name.
103103CONFLICTING_TOOL_PREFIX = f"{ RESERVED_LC_TOOL_PREFIX } tool-"
104104
105+ # Prepended to a local tool name when passed to LangChain to both avoid name conflicts
106+ # and to allow recovering tool type during LC -> SDK conversion
107+ LOCAL_TOOL_PREFIX = f"{ RESERVED_LC_TOOL_PREFIX } local-"
108+
105109AGENT_AS_TOOLS_PROMPT = f"""
106110You are provided with Agents.
107111Agents are more advanced TOOLS, which start with "{ AGENT_PREFIX } " prefix.
@@ -242,16 +246,25 @@ async def invoke_agent(req: AgentRequest) -> AgentResponse[Any | None]:
242246 )
243247
244248
249+ def _prepare_langchain_tools (agent_tools : Sequence [Tool ]) -> list [BaseTool ]:
250+ """We prefix every local tool name."""
251+ tools = list [BaseTool ]()
252+ for a_tool in agent_tools :
253+ tools .append (_create_langchain_tool (a_tool ))
254+
255+ return tools
256+
257+
245258@final
246259class LangChainBackend (Backend ):
247260 @override
248261 async def create_agent (
249262 self ,
250263 agent : BaseAgent [OutputT ],
251264 ) -> AgentImpl [OutputT ]:
252- system_prompt = agent .system_prompt
253- tools = [_create_langchain_tool (t ) for t in agent .tools ]
265+ tools = _prepare_langchain_tools (agent .tools )
254266
267+ system_prompt = agent .system_prompt
255268 if agent .agents :
256269 seen_names : set [str ] = set ()
257270 for subagent in agent .agents :
@@ -466,7 +479,8 @@ def _convert_tool_request_to_lc(
466479
467480
468481def _convert_subagent_request_to_lc (
469- request : SubagentRequest , original_request : LC_ToolCallRequest
482+ request : SubagentRequest ,
483+ original_request : LC_ToolCallRequest ,
470484) -> LC_ToolCallRequest :
471485 return original_request .override (
472486 tool_call = _map_tool_call_to_langchain (request .call ),
@@ -475,7 +489,8 @@ def _convert_subagent_request_to_lc(
475489
476490
477491def _convert_model_request_to_lc (
478- request : ModelRequest , original_request : LC_ModelRequest
492+ request : ModelRequest ,
493+ original_request : LC_ModelRequest ,
479494) -> LC_ModelRequest :
480495 return original_request .override (
481496 system_message = LC_SystemMessage (content = request .system_message ),
@@ -504,7 +519,7 @@ def _convert_tool_message_to_lc(
504519 case SubagentMessage ():
505520 name = _normalize_agent_name (message .name )
506521 case ToolMessage ():
507- name = _normalize_tool_name (message .name )
522+ name = _normalize_tool_name (message .name , message . type )
508523
509524 return LC_ToolMessage (
510525 name = name ,
@@ -515,11 +530,10 @@ def _convert_tool_message_to_lc(
515530
516531
517532def _convert_tool_response_to_lc (
518- response : ToolResponse ,
519- call : ToolCall ,
533+ response : ToolResponse , call : ToolCall
520534) -> LC_ToolMessage :
521535 return LC_ToolMessage (
522- name = _normalize_tool_name (call .name ),
536+ name = _normalize_tool_name (call .name , call . type ),
523537 content = response .content ,
524538 tool_call_id = call .id ,
525539 status = response .status ,
@@ -554,11 +568,18 @@ def _convert_tool_message_from_lc(
554568 assert message .name is not None , (
555569 "LangChain responded with a nameless tool call"
556570 )
571+
572+ tool_type : ToolType = (
573+ ToolType .LOCAL
574+ if message .name .startswith (LOCAL_TOOL_PREFIX )
575+ else ToolType .REMOTE
576+ )
557577 return ToolMessage (
558578 name = _denormalize_tool_name (message .name ),
559579 content = message .content .__str__ (),
560580 call_id = message .tool_call_id ,
561581 status = message .status ,
582+ type = tool_type ,
562583 )
563584 case LC_Command ():
564585 # NOTE: for now the command is not implemented
@@ -668,7 +689,7 @@ async def _tool_call(**kwargs: dict[str, Any]) -> dict[str, Any] | list[str]:
668689 except ToolException as e :
669690 raise LC_ToolException (* e .args ) from e
670691 except LC_ToolException :
671- assert False , (
692+ assert False , ( # noqa: PT015
672693 "ToolException from LangChain should not be raised in tool.func"
673694 )
674695
@@ -687,7 +708,7 @@ async def _tool_call(**kwargs: dict[str, Any]) -> dict[str, Any] | list[str]:
687708 return result .content
688709
689710 return StructuredTool (
690- name = _normalize_tool_name (tool .name ),
711+ name = _normalize_tool_name (tool .name , tool . type ),
691712 description = tool .description ,
692713 args_schema = tool .input_schema ,
693714 coroutine = _tool_call ,
@@ -709,14 +730,24 @@ def _denormalize_agent_name(name: str) -> str:
709730 return name .removeprefix (AGENT_PREFIX )
710731
711732
712- def _normalize_tool_name (name : str ) -> str :
733+ def _normalize_tool_name (name : str , tool_type : ToolType ) -> str :
734+ if tool_type == ToolType .LOCAL :
735+ return LOCAL_TOOL_PREFIX + name
736+
713737 if name .startswith (RESERVED_LC_TOOL_PREFIX ):
714- return f"{ CONFLICTING_TOOL_PREFIX } { name } "
738+ # Tool name contains our reserved prefix, see comment
739+ # on CONFLICTING_TOOL_PREFIX for more details
740+ return CONFLICTING_TOOL_PREFIX + name
741+
715742 return name
716743
717744
718745def _denormalize_tool_name (name : str ) -> str :
719- return name .removeprefix (CONFLICTING_TOOL_PREFIX )
746+ if name .startswith (RESERVED_LC_TOOL_PREFIX ):
747+ assert "-" in name , "Invalid prefix in tool name"
748+ _prefix , name = name .split ("-" , maxsplit = 1 )
749+
750+ return name
720751
721752
722753def _agent_as_tool (agent : BaseAgent [OutputT ]) -> StructuredTool :
@@ -757,17 +788,22 @@ async def _run(**kwargs: dict[str, Any]) -> OutputT | str:
757788
758789
759790def _map_tool_call_from_langchain (tool_call : LC_ToolCall ) -> ToolCall | SubagentCall :
760- if tool_call ["name" ].startswith (AGENT_PREFIX ):
791+ name = tool_call ["name" ]
792+ if name .startswith (AGENT_PREFIX ):
761793 return SubagentCall (
762- name = _denormalize_agent_name (tool_call [ " name" ] ),
794+ name = _denormalize_agent_name (name ),
763795 args = tool_call ["args" ],
764796 id = tool_call ["id" ],
765797 )
766798
799+ tool_type : ToolType = (
800+ ToolType .LOCAL if name .startswith (LOCAL_TOOL_PREFIX ) else ToolType .REMOTE
801+ )
767802 return ToolCall (
768- name = _denormalize_tool_name (tool_call [ " name" ] ),
803+ name = _denormalize_tool_name (name ),
769804 args = tool_call ["args" ],
770805 id = tool_call ["id" ],
806+ type = tool_type ,
771807 )
772808
773809
@@ -776,13 +812,9 @@ def _map_tool_call_to_langchain(call: ToolCall | SubagentCall) -> LC_ToolCall:
776812 case SubagentCall ():
777813 name = _normalize_agent_name (call .name )
778814 case ToolCall ():
779- name = _normalize_tool_name (call .name )
815+ name = _normalize_tool_name (call .name , call . type )
780816
781- return LC_ToolCall (
782- name = name ,
783- args = call .args ,
784- id = call .id ,
785- )
817+ return LC_ToolCall (id = call .id , name = name , args = call .args )
786818
787819
788820def _map_message_from_langchain (message : LC_BaseMessage ) -> BaseMessage :
@@ -806,7 +838,7 @@ def _map_message_to_langchain(message: BaseMessage) -> LC_AnyMessage:
806838 match message :
807839 case AIMessage ():
808840 lc_message = LC_AIMessage (content = message .content )
809- # this field can't be set via constructor
841+ # This field can't be set via constructor
810842 lc_message .tool_calls = [
811843 _map_tool_call_to_langchain (c ) for c in message .calls
812844 ]
0 commit comments