Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 122 additions & 16 deletions backend/agents/create_agent_info.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import threading
import logging
from typing import List, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin

from jinja2 import Template, StrictUndefined
Expand Down Expand Up @@ -37,12 +37,71 @@
from utils.config_utils import tenant_config_manager, get_model_name_from_config
from utils.context_utils import build_context_components
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
from consts.model import AgentToolParamsRequest, ToolParamsRequest
from consts.exceptions import ValidationError

logger = logging.getLogger("create_agent_info")
logger.setLevel(logging.DEBUG)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_normalize_tool_params_request 使用 ToolParamsRequest.model_validate 验证输入,但如果 tool_params 包含未知的 tool name,不会报错。建议在 _get_agent_tool_overrides 中添加验证,确保 agent name 存在于当前 agent 列表中,避免无效覆盖。


def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest:
    """Normalize request-scoped tool parameter overrides into a ToolParamsRequest.

    Args:
        tool_params: ``None``, an already-built ``ToolParamsRequest``, or a raw
            dict payload to validate.

    Returns:
        A new empty ``ToolParamsRequest`` when ``tool_params`` is ``None``, the
        same instance when one is passed in, otherwise the result of
        ``ToolParamsRequest.model_validate(tool_params)``.

    Raises:
        ValidationError: If ``tool_params`` is neither a dict nor a
            ``ToolParamsRequest``, or if the dict is rejected by validation.
    """
    if tool_params is None:
        return ToolParamsRequest()
    if isinstance(tool_params, ToolParamsRequest):
        return tool_params
    if not isinstance(tool_params, dict):
        raise ValidationError("tool_params must be an object.")
    try:
        return ToolParamsRequest.model_validate(tool_params)
    except (ValueError, TypeError) as exc:
        # Only payload problems are translated into ValidationError. Pydantic's
        # own validation error derives from ValueError (assumes
        # ToolParamsRequest is a pydantic model - TODO confirm). Anything else
        # is an internal bug and is left to propagate instead of being
        # reported to the caller as a bad payload.
        raise ValidationError(f"Invalid tool_params payload: {exc}") from exc


def _get_agent_tool_overrides(
tool_params: Optional[ToolParamsRequest],
agent_name: Optional[str],
) -> Dict[str, Dict[str, Any]]:
"""Resolve tool overrides for a specific agent by its name."""
if tool_params is None:
return {}
if not agent_name:
return {}
agent_override = tool_params.agents.get(agent_name)
if agent_override is None:
return {}
return dict(agent_override.tools)


def _merge_tool_params(
tool_record: Dict[str, Any],
override_params: Optional[Dict[str, Any]],
extra_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Merge request overrides on top of tool instance defaults from DB.

Args:
tool_record: Tool configuration from database
override_params: Request-scoped overrides from tool_params
extra_params: Additional internal params not in DB schema (e.g., document_paths)

Returns:
Merged params dict with DB defaults, overrides, and extra params
"""
merged_params: Dict[str, Any] = {}
for param in tool_record.get("params", []):
merged_params[param["name"]] = param.get("default")

if override_params:
merged_params.update(override_params)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] tool_params 的 override_params 直接 update 到工具参数,没有校验 key 是否属于该工具 schema。北向调用者可以注入隐藏参数或覆盖内部字段;需要按工具声明白名单过滤。

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] tool_params 的 override_params 直接 update 到工具参数,没有校验 key 是否属于该工具 schema。北向调用者可以注入隐藏参数或覆盖内部字段;需要按工具声明白名单过滤。

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

事后审查补充:[P1] tool_params 的 override_params 直接 update 到工具参数,没有校验 key 是否属于该工具 schema。北向调用者可以注入隐藏参数或覆盖内部字段;需要按工具声明白名单过滤。

影响:这个问题合入后会在对应部署、运行或权限场景中留下真实故障/安全风险,后续排查成本较高。
建议:沿着上述风险点补齐校验、配置来源、权限边界或回归测试,避免同类问题再次出现。

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

二次事后审查补充:[P1] tool_params 的 override_params 直接 update 到工具参数,没有校验 key 是否属于该工具 schema。北向调用者可以注入隐藏参数或覆盖内部字段;需要按工具声明白名单过滤。

影响:该问题合入后仍可能在真实部署、运行、权限或测试场景中形成回归风险。
建议:后续按这个风险点补齐边界校验、配置来源收敛、权限约束或针对性回归测试。


# Extra params (e.g., internal access control params) always take precedence
if extra_params:
merged_params.update(extra_params)

return merged_params


def _build_internal_s3_url(file: dict) -> str:
"""Build a valid S3 URL for internal tools from uploaded file metadata."""
if not isinstance(file, dict):
Expand Down Expand Up @@ -314,7 +373,9 @@
allow_memory_search: bool = True,
version_no: int = 0,
override_model_id: int | None = None,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
normalized_tool_params = _normalize_tool_params_request(tool_params)
agent_info = search_agent_info_by_agent_id(
agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)

Expand All @@ -338,13 +399,20 @@
allow_memory_search=allow_memory_search,
version_no=sub_agent_version_no,
override_model_id=None,
tool_params=normalized_tool_params,
)
managed_agents.append(sub_agent_config)

# create external A2A agents (synchronous function, no await needed)
external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no)

tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no)
tool_list = await create_tool_config_list(
agent_id,
tenant_id,
user_id,
version_no=version_no,
tool_params=normalized_tool_params,
)

# Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
duty_prompt = agent_info.get("duty_prompt", "")
Expand Down Expand Up @@ -569,17 +637,43 @@
return agent_config


async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0):
# create tool
async def create_tool_config_list(

Check failure on line 640 in backend/agents/create_agent_info.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this function to reduce its Cognitive Complexity from 55 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=ModelEngine-Group_nexent&issues=AZ62g2Y_t7nt5E6mViji&open=AZ62g2Y_t7nt5E6mViji&pullRequest=3223
agent_id,
tenant_id,
user_id,
version_no: int = 0,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
tool_config_list = []
langchain_tools = await discover_langchain_tools()
normalized_tool_params = _normalize_tool_params_request(tool_params)

# now only admin can modify the agent, user_id is not used
tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no)

# Look up agent name for use in error messages.
# Agent name is optional for tool_params matching (matching uses tool identifiers only),
# but we include it in error messages so callers can identify which agent/tool caused a failure.
agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
agent_name = agent_info.get("name") if agent_info else None
agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name)

tool_keys_seen = set()
for tool in tools_list:
param_dict = {}
for param in tool.get("params", []):
param_dict[param["name"]] = param.get("default")
tool_identifier = tool.get("name") or tool.get("class_name")
if tool_identifier in tool_keys_seen:
raise ValidationError(
f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'."
)
tool_keys_seen.add(tool_identifier)

override_params = None
if tool.get("name") in agent_tool_overrides:
override_params = agent_tool_overrides[tool.get("name")]
elif tool.get("class_name") in agent_tool_overrides:
override_params = agent_tool_overrides[tool.get("class_name")]

param_dict = _merge_tool_params(tool, override_params)
tool_config = ToolConfig(
class_name=tool.get("class_name"),
name=tool.get("name"),
Expand All @@ -598,20 +692,29 @@
tool_config.metadata = langchain_tool
break

# Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
document_paths = None
if override_params and "document_paths" in override_params:
document_paths = override_params.get("document_paths")
# Also check using the tool name as key
if not document_paths:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] if not document_paths 会把空列表当成“没有限制”,然后继续回退到 knowledge_base_search 覆盖或不加过滤。对访问控制来说空列表应表示不允许任何文档,不能等同 None。

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] if not document_paths 会把空列表当成“没有限制”,然后继续回退到 knowledge_base_search 覆盖或不加过滤。对访问控制来说空列表应表示不允许任何文档,不能等同 None。

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

事后审查补充:[P1] if not document_paths 会把空列表当成“没有限制”,然后继续回退到 knowledge_base_search 覆盖或不加过滤。对访问控制来说空列表应表示不允许任何文档,不能等同 None。

影响:这个问题合入后会在对应部署、运行或权限场景中留下真实故障/安全风险,后续排查成本较高。
建议:沿着上述风险点补齐校验、配置来源、权限边界或回归测试,避免同类问题再次出现。

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

二次事后审查补充:[P1] if not document_paths 会把空列表当成“没有限制”,然后继续回退到 knowledge_base_search 覆盖或不加过滤。对访问控制来说空列表应表示不允许任何文档,不能等同 None。

影响:该问题合入后仍可能在真实部署、运行、权限或测试场景中形成回归风险。
建议:后续按这个风险点补齐边界校验、配置来源收敛、权限约束或针对性回归测试。

kb_overrides = agent_tool_overrides.get("knowledge_base_search")
if kb_overrides and "document_paths" in kb_overrides:
document_paths = kb_overrides.get("document_paths")

# special logic for search tools that may use reranking models
if tool_config.class_name == "KnowledgeBaseSearchTool":
rerank = param_dict.get("rerank", False)
rerank_model_name = param_dict.get("rerank_model_name", "")
rerank = tool_config.params.get("rerank", False)
rerank_model_name = tool_config.params.get("rerank_model_name", "")
rerank_model = None
is_multimodal = bool(tool_config.params.pop("multimodal", False))
if rerank and rerank_model_name:
rerank_model = get_rerank_model(
tenant_id=tenant_id, model_name=rerank_model_name
)

# Build display_name to index_name mapping for LLM parameter conversion
# Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
index_names = param_dict.get("index_names", [])
index_names = tool_config.params.get("index_names", [])
display_name_to_index_map = {}
index_name_to_display_map = {}
if index_names:
Expand All @@ -627,12 +730,14 @@
"rerank_model": rerank_model,
"display_name_to_index_map": display_name_to_index_map,
"index_name_to_display_map": index_name_to_display_map,
# Internal access control: restrict results to specific document paths (path_or_urls)
"document_paths": document_paths,
}

# Must have embedding model for knowledge base search
if not index_names:
raise ValidationError(
"Embedding model is required for knowledge_base_search but index_names is empty")
f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, "
f"but it is not configured in the agent and not provided via tool_params.")

embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0])
if not embedding_model:
Expand All @@ -641,8 +746,8 @@
f"Please configure an embedding model for this knowledge base.")
tool_config.metadata["embedding_model"] = embedding_model
elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
rerank = param_dict.get("rerank", False)
rerank_model_name = param_dict.get("rerank_model_name", "")
rerank = tool_config.params.get("rerank", False)
rerank_model_name = tool_config.params.get("rerank_model_name", "")
rerank_model = None
if rerank and rerank_model_name:
rerank_model = get_rerank_model(
Expand Down Expand Up @@ -936,6 +1041,7 @@
is_debug: bool = False,
override_version_no: int | None = None,
override_model_id: int | None = None,
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
):
# Determine which version_no to use based on is_debug flag
# If is_debug=false, use the current published version (current_version_no)
Expand Down Expand Up @@ -968,7 +1074,7 @@
if override_model_id is not None:
create_config_kwargs["override_model_id"] = override_model_id

agent_config = await create_agent_config(**create_config_kwargs)
agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params)

remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True)
default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse")
Expand Down
Loading
Loading