11import json
22import threading
33import logging
4- from typing import List , Optional
4+ from typing import Any , Dict , List , Optional
55from urllib .parse import urljoin
66
77from jinja2 import Template , StrictUndefined
3737from utils .config_utils import tenant_config_manager , get_model_name_from_config
3838from utils .context_utils import build_context_components
3939from consts .const import LOCAL_MCP_SERVER , MODEL_CONFIG_MAPPING , LANGUAGE , DATA_PROCESS_SERVICE , MINIO_DEFAULT_BUCKET
40+ from consts .model import AgentToolParamsRequest , ToolParamsRequest
4041from consts .exceptions import ValidationError
4142
4243logger = logging .getLogger ("create_agent_info" )
4344logger .setLevel (logging .DEBUG )
4445
4546
47+ def _normalize_tool_params_request (tool_params : Optional [ToolParamsRequest | Dict [str , Any ]]) -> ToolParamsRequest :
48+ """Normalize request-scoped tool parameter overrides into a ToolParamsRequest."""
49+ if tool_params is None :
50+ return ToolParamsRequest ()
51+ if isinstance (tool_params , ToolParamsRequest ):
52+ return tool_params
53+ if not isinstance (tool_params , dict ):
54+ raise ValidationError ("tool_params must be an object." )
55+ try :
56+ return ToolParamsRequest .model_validate (tool_params )
57+ except Exception as exc :
58+ raise ValidationError (f"Invalid tool_params payload: { exc } " ) from exc
59+
60+
61+ def _get_agent_tool_overrides (
62+ tool_params : Optional [ToolParamsRequest ],
63+ agent_name : Optional [str ],
64+ ) -> Dict [str , Dict [str , Any ]]:
65+ """Resolve tool overrides for a specific agent by its name."""
66+ if tool_params is None :
67+ return {}
68+ if not agent_name :
69+ return {}
70+ agent_override = tool_params .agents .get (agent_name )
71+ if agent_override is None :
72+ return {}
73+ return dict (agent_override .tools )
74+
75+
76+ def _merge_tool_params (
77+ tool_record : Dict [str , Any ],
78+ override_params : Optional [Dict [str , Any ]],
79+ extra_params : Optional [Dict [str , Any ]] = None ,
80+ ) -> Dict [str , Any ]:
81+ """Merge request overrides on top of tool instance defaults from DB.
82+
83+ Args:
84+ tool_record: Tool configuration from database
85+ override_params: Request-scoped overrides from tool_params
86+ extra_params: Additional internal params not in DB schema (e.g., document_paths)
87+
88+ Returns:
89+ Merged params dict with DB defaults, overrides, and extra params
90+ """
91+ merged_params : Dict [str , Any ] = {}
92+ for param in tool_record .get ("params" , []):
93+ merged_params [param ["name" ]] = param .get ("default" )
94+
95+ if override_params :
96+ merged_params .update (override_params )
97+
98+ # Extra params (e.g., internal access control params) always take precedence
99+ if extra_params :
100+ merged_params .update (extra_params )
101+
102+ return merged_params
103+
104+
46105def _build_internal_s3_url (file : dict ) -> str :
47106 """Build a valid S3 URL for internal tools from uploaded file metadata."""
48107 if not isinstance (file , dict ):
@@ -314,7 +373,9 @@ async def create_agent_config(
314373 allow_memory_search : bool = True ,
315374 version_no : int = 0 ,
316375 override_model_id : int | None = None ,
376+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
317377):
378+ normalized_tool_params = _normalize_tool_params_request (tool_params )
318379 agent_info = search_agent_info_by_agent_id (
319380 agent_id = agent_id , tenant_id = tenant_id , version_no = version_no )
320381
@@ -338,13 +399,20 @@ async def create_agent_config(
338399 allow_memory_search = allow_memory_search ,
339400 version_no = sub_agent_version_no ,
340401 override_model_id = None ,
402+ tool_params = normalized_tool_params ,
341403 )
342404 managed_agents .append (sub_agent_config )
343405
344406 # create external A2A agents (synchronous function, no await needed)
345407 external_a2a_agents = _get_external_a2a_agents (agent_id , tenant_id , version_no )
346408
347- tool_list = await create_tool_config_list (agent_id , tenant_id , user_id , version_no = version_no )
409+ tool_list = await create_tool_config_list (
410+ agent_id ,
411+ tenant_id ,
412+ user_id ,
413+ version_no = version_no ,
414+ tool_params = normalized_tool_params ,
415+ )
348416
349417 # Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
350418 duty_prompt = agent_info .get ("duty_prompt" , "" )
@@ -570,17 +638,43 @@ async def create_agent_config(
570638 return agent_config
571639
572640
573- async def create_tool_config_list (agent_id , tenant_id , user_id , version_no : int = 0 ):
574- # create tool
641+ async def create_tool_config_list (
642+ agent_id ,
643+ tenant_id ,
644+ user_id ,
645+ version_no : int = 0 ,
646+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
647+ ):
575648 tool_config_list = []
576649 langchain_tools = await discover_langchain_tools ()
650+ normalized_tool_params = _normalize_tool_params_request (tool_params )
577651
578652 # now only admin can modify the agent, user_id is not used
579653 tools_list = search_tools_for_sub_agent (agent_id , tenant_id , version_no = version_no )
654+
655+ # Look up agent name for use in error messages.
656+ # Agent name is optional for tool_params matching (matching uses tool identifiers only),
657+ # but we include it in error messages so callers can identify which agent/tool caused a failure.
658+ agent_info = search_agent_info_by_agent_id (agent_id = agent_id , tenant_id = tenant_id , version_no = version_no )
659+ agent_name = agent_info .get ("name" ) if agent_info else None
660+ agent_tool_overrides = _get_agent_tool_overrides (normalized_tool_params , agent_name )
661+
662+ tool_keys_seen = set ()
580663 for tool in tools_list :
581- param_dict = {}
582- for param in tool .get ("params" , []):
583- param_dict [param ["name" ]] = param .get ("default" )
664+ tool_identifier = tool .get ("name" ) or tool .get ("class_name" )
665+ if tool_identifier in tool_keys_seen :
666+ raise ValidationError (
667+ f"Duplicate tool identifier '{ tool_identifier } ' found in agent '{ agent_name or agent_id } '."
668+ )
669+ tool_keys_seen .add (tool_identifier )
670+
671+ override_params = None
672+ if tool .get ("name" ) in agent_tool_overrides :
673+ override_params = agent_tool_overrides [tool .get ("name" )]
674+ elif tool .get ("class_name" ) in agent_tool_overrides :
675+ override_params = agent_tool_overrides [tool .get ("class_name" )]
676+
677+ param_dict = _merge_tool_params (tool , override_params )
584678 tool_config = ToolConfig (
585679 class_name = tool .get ("class_name" ),
586680 name = tool .get ("name" ),
@@ -599,20 +693,29 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
599693 tool_config .metadata = langchain_tool
600694 break
601695
696+ # Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
697+ document_paths = None
698+ if override_params and "document_paths" in override_params :
699+ document_paths = override_params .get ("document_paths" )
700+ # Also check using the tool name as key
701+ if not document_paths :
702+ kb_overrides = agent_tool_overrides .get ("knowledge_base_search" )
703+ if kb_overrides and "document_paths" in kb_overrides :
704+ document_paths = kb_overrides .get ("document_paths" )
705+
602706 # special logic for search tools that may use reranking models
603707 if tool_config .class_name == "KnowledgeBaseSearchTool" :
604- rerank = param_dict .get ("rerank" , False )
605- rerank_model_name = param_dict .get ("rerank_model_name" , "" )
708+ rerank = tool_config . params .get ("rerank" , False )
709+ rerank_model_name = tool_config . params .get ("rerank_model_name" , "" )
606710 rerank_model = None
607- is_multimodal = bool (tool_config .params .pop ("multimodal" , False ))
608711 if rerank and rerank_model_name :
609712 rerank_model = get_rerank_model (
610713 tenant_id = tenant_id , model_name = rerank_model_name
611714 )
612715
613716 # Build display_name to index_name mapping for LLM parameter conversion
614717 # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
615- index_names = param_dict .get ("index_names" , [])
718+ index_names = tool_config . params .get ("index_names" , [])
616719 display_name_to_index_map = {}
617720 index_name_to_display_map = {}
618721 if index_names :
@@ -628,12 +731,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
628731 "rerank_model" : rerank_model ,
629732 "display_name_to_index_map" : display_name_to_index_map ,
630733 "index_name_to_display_map" : index_name_to_display_map ,
734+ # Internal access control: restrict results to specific document paths (path_or_urls)
735+ "document_paths" : document_paths ,
631736 }
632737
633- # Must have embedding model for knowledge base search
634738 if not index_names :
635739 raise ValidationError (
636- "Embedding model is required for knowledge_base_search but index_names is empty" )
740+ f"[{ agent_name or agent_id } ] knowledge_base_search tool requires index_names, "
741+ f"but it is not configured in the agent and not provided via tool_params." )
637742
638743 embedding_model , _ , _ = get_embedding_model_by_index_name (tenant_id , index_names [0 ])
639744 if not embedding_model :
@@ -642,8 +747,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
642747 f"Please configure an embedding model for this knowledge base." )
643748 tool_config .metadata ["embedding_model" ] = embedding_model
644749 elif tool_config .class_name in ["DifySearchTool" , "DataMateSearchTool" ]:
645- rerank = param_dict .get ("rerank" , False )
646- rerank_model_name = param_dict .get ("rerank_model_name" , "" )
750+ rerank = tool_config . params .get ("rerank" , False )
751+ rerank_model_name = tool_config . params .get ("rerank_model_name" , "" )
647752 rerank_model = None
648753 if rerank and rerank_model_name :
649754 rerank_model = get_rerank_model (
@@ -937,6 +1042,7 @@ async def create_agent_run_info(
9371042 is_debug : bool = False ,
9381043 override_version_no : int | None = None ,
9391044 override_model_id : int | None = None ,
1045+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
9401046):
9411047 # Determine which version_no to use based on is_debug flag
9421048 # If is_debug=false, use the current published version (current_version_no)
@@ -969,7 +1075,7 @@ async def create_agent_run_info(
9691075 if override_model_id is not None :
9701076 create_config_kwargs ["override_model_id" ] = override_model_id
9711077
972- agent_config = await create_agent_config (** create_config_kwargs )
1078+ agent_config = await create_agent_config (** create_config_kwargs , tool_params = tool_params )
9731079
9741080 remote_mcp_list = await get_remote_mcp_server_list (tenant_id = tenant_id , is_need_auth = True )
9751081 default_mcp_url = urljoin (LOCAL_MCP_SERVER , "sse" )
0 commit comments