Skip to content

Commit c38bdd8

Browse files
committed
✨ Add several northbound apis
1 parent b6b6027 commit c38bdd8

12 files changed

Lines changed: 1984 additions & 616 deletions

backend/agents/create_agent_info.py

Lines changed: 122 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import threading
33
import logging
4-
from typing import List, Optional
4+
from typing import Any, Dict, List, Optional
55
from urllib.parse import urljoin
66

77
from jinja2 import Template, StrictUndefined
@@ -33,12 +33,71 @@
3333
from utils.config_utils import tenant_config_manager, get_model_name_from_config
3434
from utils.context_utils import build_context_components
3535
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
36+
from consts.model import AgentToolParamsRequest, ToolParamsRequest
3637
from consts.exceptions import ValidationError
3738

3839
logger = logging.getLogger("create_agent_info")
3940
logger.setLevel(logging.DEBUG)
4041

4142

43+
def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest:
44+
"""Normalize request-scoped tool parameter overrides into a ToolParamsRequest."""
45+
if tool_params is None:
46+
return ToolParamsRequest()
47+
if isinstance(tool_params, ToolParamsRequest):
48+
return tool_params
49+
if not isinstance(tool_params, dict):
50+
raise ValidationError("tool_params must be an object.")
51+
try:
52+
return ToolParamsRequest.model_validate(tool_params)
53+
except Exception as exc:
54+
raise ValidationError(f"Invalid tool_params payload: {exc}") from exc
55+
56+
57+
def _get_agent_tool_overrides(
58+
tool_params: Optional[ToolParamsRequest],
59+
agent_name: Optional[str],
60+
) -> Dict[str, Dict[str, Any]]:
61+
"""Resolve tool overrides for a specific agent by its name."""
62+
if tool_params is None:
63+
return {}
64+
if not agent_name:
65+
return {}
66+
agent_override = tool_params.agents.get(agent_name)
67+
if agent_override is None:
68+
return {}
69+
return dict(agent_override.tools)
70+
71+
72+
def _merge_tool_params(
73+
tool_record: Dict[str, Any],
74+
override_params: Optional[Dict[str, Any]],
75+
extra_params: Optional[Dict[str, Any]] = None,
76+
) -> Dict[str, Any]:
77+
"""Merge request overrides on top of tool instance defaults from DB.
78+
79+
Args:
80+
tool_record: Tool configuration from database
81+
override_params: Request-scoped overrides from tool_params
82+
extra_params: Additional internal params not in DB schema (e.g., document_paths)
83+
84+
Returns:
85+
Merged params dict with DB defaults, overrides, and extra params
86+
"""
87+
merged_params: Dict[str, Any] = {}
88+
for param in tool_record.get("params", []):
89+
merged_params[param["name"]] = param.get("default")
90+
91+
if override_params:
92+
merged_params.update(override_params)
93+
94+
# Extra params (e.g., internal access control params) always take precedence
95+
if extra_params:
96+
merged_params.update(extra_params)
97+
98+
return merged_params
99+
100+
42101
def _build_internal_s3_url(file: dict) -> str:
43102
"""Build a valid S3 URL for internal tools from uploaded file metadata."""
44103
if not isinstance(file, dict):
@@ -310,7 +369,9 @@ async def create_agent_config(
310369
allow_memory_search: bool = True,
311370
version_no: int = 0,
312371
override_model_id: int | None = None,
372+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
313373
):
374+
normalized_tool_params = _normalize_tool_params_request(tool_params)
314375
agent_info = search_agent_info_by_agent_id(
315376
agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
316377

@@ -331,13 +392,20 @@ async def create_agent_config(
331392
allow_memory_search=allow_memory_search,
332393
version_no=sub_agent_version_no,
333394
override_model_id=None,
395+
tool_params=normalized_tool_params,
334396
)
335397
managed_agents.append(sub_agent_config)
336398

337399
# create external A2A agents (synchronous function, no await needed)
338400
external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no)
339401

340-
tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no)
402+
tool_list = await create_tool_config_list(
403+
agent_id,
404+
tenant_id,
405+
user_id,
406+
version_no=version_no,
407+
tool_params=normalized_tool_params,
408+
)
341409

342410
# Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
343411
duty_prompt = agent_info.get("duty_prompt", "")
@@ -562,17 +630,43 @@ async def create_agent_config(
562630
return agent_config
563631

564632

565-
async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0):
566-
# create tool
633+
async def create_tool_config_list(
634+
agent_id,
635+
tenant_id,
636+
user_id,
637+
version_no: int = 0,
638+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
639+
):
567640
tool_config_list = []
568641
langchain_tools = await discover_langchain_tools()
642+
normalized_tool_params = _normalize_tool_params_request(tool_params)
569643

570644
# now only admin can modify the agent, user_id is not used
571645
tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no)
646+
647+
# Look up agent name for use in error messages.
648+
# Agent name is optional for tool_params matching (matching uses tool identifiers only),
649+
# but we include it in error messages so callers can identify which agent/tool caused a failure.
650+
agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
651+
agent_name = agent_info.get("name") if agent_info else None
652+
agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name)
653+
654+
tool_keys_seen = set()
572655
for tool in tools_list:
573-
param_dict = {}
574-
for param in tool.get("params", []):
575-
param_dict[param["name"]] = param.get("default")
656+
tool_identifier = tool.get("name") or tool.get("class_name")
657+
if tool_identifier in tool_keys_seen:
658+
raise ValidationError(
659+
f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'."
660+
)
661+
tool_keys_seen.add(tool_identifier)
662+
663+
override_params = None
664+
if tool.get("name") in agent_tool_overrides:
665+
override_params = agent_tool_overrides[tool.get("name")]
666+
elif tool.get("class_name") in agent_tool_overrides:
667+
override_params = agent_tool_overrides[tool.get("class_name")]
668+
669+
param_dict = _merge_tool_params(tool, override_params)
576670
tool_config = ToolConfig(
577671
class_name=tool.get("class_name"),
578672
name=tool.get("name"),
@@ -591,20 +685,29 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
591685
tool_config.metadata = langchain_tool
592686
break
593687

688+
# Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
689+
document_paths = None
690+
if override_params and "document_paths" in override_params:
691+
document_paths = override_params.get("document_paths")
692+
# Also check using the tool name as key
693+
if not document_paths:
694+
kb_overrides = agent_tool_overrides.get("knowledge_base_search")
695+
if kb_overrides and "document_paths" in kb_overrides:
696+
document_paths = kb_overrides.get("document_paths")
697+
594698
# special logic for search tools that may use reranking models
595699
if tool_config.class_name == "KnowledgeBaseSearchTool":
596-
rerank = param_dict.get("rerank", False)
597-
rerank_model_name = param_dict.get("rerank_model_name", "")
700+
rerank = tool_config.params.get("rerank", False)
701+
rerank_model_name = tool_config.params.get("rerank_model_name", "")
598702
rerank_model = None
599-
is_multimodal = bool(tool_config.params.pop("multimodal", False))
600703
if rerank and rerank_model_name:
601704
rerank_model = get_rerank_model(
602705
tenant_id=tenant_id, model_name=rerank_model_name
603706
)
604707

605708
# Build display_name to index_name mapping for LLM parameter conversion
606709
# Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
607-
index_names = param_dict.get("index_names", [])
710+
index_names = tool_config.params.get("index_names", [])
608711
display_name_to_index_map = {}
609712
index_name_to_display_map = {}
610713
if index_names:
@@ -620,12 +723,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
620723
"rerank_model": rerank_model,
621724
"display_name_to_index_map": display_name_to_index_map,
622725
"index_name_to_display_map": index_name_to_display_map,
726+
# Internal access control: restrict results to specific document paths (path_or_urls)
727+
"document_paths": document_paths,
623728
}
624729

625-
# Must have embedding model for knowledge base search
626730
if not index_names:
627731
raise ValidationError(
628-
"Embedding model is required for knowledge_base_search but index_names is empty")
732+
f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, "
733+
f"but it is not configured in the agent and not provided via tool_params.")
629734

630735
embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0])
631736
if not embedding_model:
@@ -634,8 +739,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
634739
f"Please configure an embedding model for this knowledge base.")
635740
tool_config.metadata["embedding_model"] = embedding_model
636741
elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
637-
rerank = param_dict.get("rerank", False)
638-
rerank_model_name = param_dict.get("rerank_model_name", "")
742+
rerank = tool_config.params.get("rerank", False)
743+
rerank_model_name = tool_config.params.get("rerank_model_name", "")
639744
rerank_model = None
640745
if rerank and rerank_model_name:
641746
rerank_model = get_rerank_model(
@@ -929,6 +1034,7 @@ async def create_agent_run_info(
9291034
is_debug: bool = False,
9301035
override_version_no: int | None = None,
9311036
override_model_id: int | None = None,
1037+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
9321038
):
9331039
# Determine which version_no to use based on is_debug flag
9341040
# If is_debug=false, use the current published version (current_version_no)
@@ -961,7 +1067,7 @@ async def create_agent_run_info(
9611067
if override_model_id is not None:
9621068
create_config_kwargs["override_model_id"] = override_model_id
9631069

964-
agent_config = await create_agent_config(**create_config_kwargs)
1070+
agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params)
9651071

9661072
remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True)
9671073
default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse")

0 commit comments

Comments
 (0)