11import json
22import threading
33import logging
4- from typing import List , Optional
4+ from typing import Any , Dict , List , Optional
55from urllib .parse import urljoin
66
77from jinja2 import Template , StrictUndefined
3333from utils .config_utils import tenant_config_manager , get_model_name_from_config
3434from utils .context_utils import build_context_components
3535from consts .const import LOCAL_MCP_SERVER , MODEL_CONFIG_MAPPING , LANGUAGE , DATA_PROCESS_SERVICE , MINIO_DEFAULT_BUCKET
36+ from consts .model import AgentToolParamsRequest , ToolParamsRequest
3637from consts .exceptions import ValidationError
3738
3839logger = logging .getLogger ("create_agent_info" )
3940logger .setLevel (logging .DEBUG )
4041
4142
43+ def _normalize_tool_params_request (tool_params : Optional [ToolParamsRequest | Dict [str , Any ]]) -> ToolParamsRequest :
44+ """Normalize request-scoped tool parameter overrides into a ToolParamsRequest."""
45+ if tool_params is None :
46+ return ToolParamsRequest ()
47+ if isinstance (tool_params , ToolParamsRequest ):
48+ return tool_params
49+ if not isinstance (tool_params , dict ):
50+ raise ValidationError ("tool_params must be an object." )
51+ try :
52+ return ToolParamsRequest .model_validate (tool_params )
53+ except Exception as exc :
54+ raise ValidationError (f"Invalid tool_params payload: { exc } " ) from exc
55+
56+
57+ def _get_agent_tool_overrides (
58+ tool_params : Optional [ToolParamsRequest ],
59+ agent_name : Optional [str ],
60+ ) -> Dict [str , Dict [str , Any ]]:
61+ """Resolve tool overrides for a specific agent by its name."""
62+ if tool_params is None :
63+ return {}
64+ if not agent_name :
65+ return {}
66+ agent_override = tool_params .agents .get (agent_name )
67+ if agent_override is None :
68+ return {}
69+ return dict (agent_override .tools )
70+
71+
72+ def _merge_tool_params (
73+ tool_record : Dict [str , Any ],
74+ override_params : Optional [Dict [str , Any ]],
75+ extra_params : Optional [Dict [str , Any ]] = None ,
76+ ) -> Dict [str , Any ]:
77+ """Merge request overrides on top of tool instance defaults from DB.
78+
79+ Args:
80+ tool_record: Tool configuration from database
81+ override_params: Request-scoped overrides from tool_params
82+ extra_params: Additional internal params not in DB schema (e.g., document_paths)
83+
84+ Returns:
85+ Merged params dict with DB defaults, overrides, and extra params
86+ """
87+ merged_params : Dict [str , Any ] = {}
88+ for param in tool_record .get ("params" , []):
89+ merged_params [param ["name" ]] = param .get ("default" )
90+
91+ if override_params :
92+ merged_params .update (override_params )
93+
94+ # Extra params (e.g., internal access control params) always take precedence
95+ if extra_params :
96+ merged_params .update (extra_params )
97+
98+ return merged_params
99+
100+
42101def _build_internal_s3_url (file : dict ) -> str :
43102 """Build a valid S3 URL for internal tools from uploaded file metadata."""
44103 if not isinstance (file , dict ):
@@ -310,7 +369,9 @@ async def create_agent_config(
310369 allow_memory_search : bool = True ,
311370 version_no : int = 0 ,
312371 override_model_id : int | None = None ,
372+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
313373):
374+ normalized_tool_params = _normalize_tool_params_request (tool_params )
314375 agent_info = search_agent_info_by_agent_id (
315376 agent_id = agent_id , tenant_id = tenant_id , version_no = version_no )
316377
@@ -331,13 +392,20 @@ async def create_agent_config(
331392 allow_memory_search = allow_memory_search ,
332393 version_no = sub_agent_version_no ,
333394 override_model_id = None ,
395+ tool_params = normalized_tool_params ,
334396 )
335397 managed_agents .append (sub_agent_config )
336398
337399 # create external A2A agents (synchronous function, no await needed)
338400 external_a2a_agents = _get_external_a2a_agents (agent_id , tenant_id , version_no )
339401
340- tool_list = await create_tool_config_list (agent_id , tenant_id , user_id , version_no = version_no )
402+ tool_list = await create_tool_config_list (
403+ agent_id ,
404+ tenant_id ,
405+ user_id ,
406+ version_no = version_no ,
407+ tool_params = normalized_tool_params ,
408+ )
341409
342410 # Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
343411 duty_prompt = agent_info .get ("duty_prompt" , "" )
@@ -562,17 +630,43 @@ async def create_agent_config(
562630 return agent_config
563631
564632
565- async def create_tool_config_list (agent_id , tenant_id , user_id , version_no : int = 0 ):
566- # create tool
633+ async def create_tool_config_list (
634+ agent_id ,
635+ tenant_id ,
636+ user_id ,
637+ version_no : int = 0 ,
638+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
639+ ):
567640 tool_config_list = []
568641 langchain_tools = await discover_langchain_tools ()
642+ normalized_tool_params = _normalize_tool_params_request (tool_params )
569643
570644 # now only admin can modify the agent, user_id is not used
571645 tools_list = search_tools_for_sub_agent (agent_id , tenant_id , version_no = version_no )
646+
647+ # Look up agent name for use in error messages.
648+ # Agent name is optional for tool_params matching (matching uses tool identifiers only),
649+ # but we include it in error messages so callers can identify which agent/tool caused a failure.
650+ agent_info = search_agent_info_by_agent_id (agent_id = agent_id , tenant_id = tenant_id , version_no = version_no )
651+ agent_name = agent_info .get ("name" ) if agent_info else None
652+ agent_tool_overrides = _get_agent_tool_overrides (normalized_tool_params , agent_name )
653+
654+ tool_keys_seen = set ()
572655 for tool in tools_list :
573- param_dict = {}
574- for param in tool .get ("params" , []):
575- param_dict [param ["name" ]] = param .get ("default" )
656+ tool_identifier = tool .get ("name" ) or tool .get ("class_name" )
657+ if tool_identifier in tool_keys_seen :
658+ raise ValidationError (
659+ f"Duplicate tool identifier '{ tool_identifier } ' found in agent '{ agent_name or agent_id } '."
660+ )
661+ tool_keys_seen .add (tool_identifier )
662+
663+ override_params = None
664+ if tool .get ("name" ) in agent_tool_overrides :
665+ override_params = agent_tool_overrides [tool .get ("name" )]
666+ elif tool .get ("class_name" ) in agent_tool_overrides :
667+ override_params = agent_tool_overrides [tool .get ("class_name" )]
668+
669+ param_dict = _merge_tool_params (tool , override_params )
576670 tool_config = ToolConfig (
577671 class_name = tool .get ("class_name" ),
578672 name = tool .get ("name" ),
@@ -591,20 +685,29 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
591685 tool_config .metadata = langchain_tool
592686 break
593687
688+ # Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
689+ document_paths = None
690+ if override_params and "document_paths" in override_params :
691+ document_paths = override_params .get ("document_paths" )
692+ # Also check using the tool name as key
693+ if not document_paths :
694+ kb_overrides = agent_tool_overrides .get ("knowledge_base_search" )
695+ if kb_overrides and "document_paths" in kb_overrides :
696+ document_paths = kb_overrides .get ("document_paths" )
697+
594698 # special logic for search tools that may use reranking models
595699 if tool_config .class_name == "KnowledgeBaseSearchTool" :
596- rerank = param_dict .get ("rerank" , False )
597- rerank_model_name = param_dict .get ("rerank_model_name" , "" )
700+ rerank = tool_config . params .get ("rerank" , False )
701+ rerank_model_name = tool_config . params .get ("rerank_model_name" , "" )
598702 rerank_model = None
599- is_multimodal = bool (tool_config .params .pop ("multimodal" , False ))
600703 if rerank and rerank_model_name :
601704 rerank_model = get_rerank_model (
602705 tenant_id = tenant_id , model_name = rerank_model_name
603706 )
604707
605708 # Build display_name to index_name mapping for LLM parameter conversion
606709 # Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
607- index_names = param_dict .get ("index_names" , [])
710+ index_names = tool_config . params .get ("index_names" , [])
608711 display_name_to_index_map = {}
609712 index_name_to_display_map = {}
610713 if index_names :
@@ -620,12 +723,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
620723 "rerank_model" : rerank_model ,
621724 "display_name_to_index_map" : display_name_to_index_map ,
622725 "index_name_to_display_map" : index_name_to_display_map ,
726+ # Internal access control: restrict results to specific document paths (path_or_urls)
727+ "document_paths" : document_paths ,
623728 }
624729
625- # Must have embedding model for knowledge base search
626730 if not index_names :
627731 raise ValidationError (
628- "Embedding model is required for knowledge_base_search but index_names is empty" )
732+ f"[{ agent_name or agent_id } ] knowledge_base_search tool requires index_names, "
733+ f"but it is not configured in the agent and not provided via tool_params." )
629734
630735 embedding_model , _ , _ = get_embedding_model_by_index_name (tenant_id , index_names [0 ])
631736 if not embedding_model :
@@ -634,8 +739,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
634739 f"Please configure an embedding model for this knowledge base." )
635740 tool_config .metadata ["embedding_model" ] = embedding_model
636741 elif tool_config .class_name in ["DifySearchTool" , "DataMateSearchTool" ]:
637- rerank = param_dict .get ("rerank" , False )
638- rerank_model_name = param_dict .get ("rerank_model_name" , "" )
742+ rerank = tool_config . params .get ("rerank" , False )
743+ rerank_model_name = tool_config . params .get ("rerank_model_name" , "" )
639744 rerank_model = None
640745 if rerank and rerank_model_name :
641746 rerank_model = get_rerank_model (
@@ -929,6 +1034,7 @@ async def create_agent_run_info(
9291034 is_debug : bool = False ,
9301035 override_version_no : int | None = None ,
9311036 override_model_id : int | None = None ,
1037+ tool_params : Optional [ToolParamsRequest | Dict [str , Any ]] = None ,
9321038):
9331039 # Determine which version_no to use based on is_debug flag
9341040 # If is_debug=false, use the current published version (current_version_no)
@@ -961,7 +1067,7 @@ async def create_agent_run_info(
9611067 if override_model_id is not None :
9621068 create_config_kwargs ["override_model_id" ] = override_model_id
9631069
964- agent_config = await create_agent_config (** create_config_kwargs )
1070+ agent_config = await create_agent_config (** create_config_kwargs , tool_params = tool_params )
9651071
9661072 remote_mcp_list = await get_remote_mcp_server_list (tenant_id = tenant_id , is_need_auth = True )
9671073 default_mcp_url = urljoin (LOCAL_MCP_SERVER , "sse" )
0 commit comments