Skip to content

Commit 4a0e4ee

Browse files
authored
✨ Add several northbound apis (#3223)
* ✨ Add several northbound apis * ✨ Add several northbound apis * ✨ Add several northbound apis * ✨ Add several northbound apis * ✨ Add several northbound apis
1 parent 7be83d7 commit 4a0e4ee

14 files changed

Lines changed: 3172 additions & 763 deletions

backend/agents/create_agent_info.py

Lines changed: 122 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import threading
33
import logging
4-
from typing import List, Optional
4+
from typing import Any, Dict, List, Optional
55
from urllib.parse import urljoin
66

77
from jinja2 import Template, StrictUndefined
@@ -37,12 +37,71 @@
3737
from utils.config_utils import tenant_config_manager, get_model_name_from_config
3838
from utils.context_utils import build_context_components
3939
from consts.const import LOCAL_MCP_SERVER, MODEL_CONFIG_MAPPING, LANGUAGE, DATA_PROCESS_SERVICE, MINIO_DEFAULT_BUCKET
40+
from consts.model import AgentToolParamsRequest, ToolParamsRequest
4041
from consts.exceptions import ValidationError
4142

4243
logger = logging.getLogger("create_agent_info")
4344
logger.setLevel(logging.DEBUG)
4445

4546

47+
def _normalize_tool_params_request(tool_params: Optional[ToolParamsRequest | Dict[str, Any]]) -> ToolParamsRequest:
48+
"""Normalize request-scoped tool parameter overrides into a ToolParamsRequest."""
49+
if tool_params is None:
50+
return ToolParamsRequest()
51+
if isinstance(tool_params, ToolParamsRequest):
52+
return tool_params
53+
if not isinstance(tool_params, dict):
54+
raise ValidationError("tool_params must be an object.")
55+
try:
56+
return ToolParamsRequest.model_validate(tool_params)
57+
except Exception as exc:
58+
raise ValidationError(f"Invalid tool_params payload: {exc}") from exc
59+
60+
61+
def _get_agent_tool_overrides(
62+
tool_params: Optional[ToolParamsRequest],
63+
agent_name: Optional[str],
64+
) -> Dict[str, Dict[str, Any]]:
65+
"""Resolve tool overrides for a specific agent by its name."""
66+
if tool_params is None:
67+
return {}
68+
if not agent_name:
69+
return {}
70+
agent_override = tool_params.agents.get(agent_name)
71+
if agent_override is None:
72+
return {}
73+
return dict(agent_override.tools)
74+
75+
76+
def _merge_tool_params(
77+
tool_record: Dict[str, Any],
78+
override_params: Optional[Dict[str, Any]],
79+
extra_params: Optional[Dict[str, Any]] = None,
80+
) -> Dict[str, Any]:
81+
"""Merge request overrides on top of tool instance defaults from DB.
82+
83+
Args:
84+
tool_record: Tool configuration from database
85+
override_params: Request-scoped overrides from tool_params
86+
extra_params: Additional internal params not in DB schema (e.g., document_paths)
87+
88+
Returns:
89+
Merged params dict with DB defaults, overrides, and extra params
90+
"""
91+
merged_params: Dict[str, Any] = {}
92+
for param in tool_record.get("params", []):
93+
merged_params[param["name"]] = param.get("default")
94+
95+
if override_params:
96+
merged_params.update(override_params)
97+
98+
# Extra params (e.g., internal access control params) always take precedence
99+
if extra_params:
100+
merged_params.update(extra_params)
101+
102+
return merged_params
103+
104+
46105
def _build_internal_s3_url(file: dict) -> str:
47106
"""Build a valid S3 URL for internal tools from uploaded file metadata."""
48107
if not isinstance(file, dict):
@@ -314,7 +373,9 @@ async def create_agent_config(
314373
allow_memory_search: bool = True,
315374
version_no: int = 0,
316375
override_model_id: int | None = None,
376+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
317377
):
378+
normalized_tool_params = _normalize_tool_params_request(tool_params)
318379
agent_info = search_agent_info_by_agent_id(
319380
agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
320381

@@ -338,13 +399,20 @@ async def create_agent_config(
338399
allow_memory_search=allow_memory_search,
339400
version_no=sub_agent_version_no,
340401
override_model_id=None,
402+
tool_params=normalized_tool_params,
341403
)
342404
managed_agents.append(sub_agent_config)
343405

344406
# create external A2A agents (synchronous function, no await needed)
345407
external_a2a_agents = _get_external_a2a_agents(agent_id, tenant_id, version_no)
346408

347-
tool_list = await create_tool_config_list(agent_id, tenant_id, user_id, version_no=version_no)
409+
tool_list = await create_tool_config_list(
410+
agent_id,
411+
tenant_id,
412+
user_id,
413+
version_no=version_no,
414+
tool_params=normalized_tool_params,
415+
)
348416

349417
# Build system prompt: prioritize segmented fields, fallback to original prompt field if not available
350418
duty_prompt = agent_info.get("duty_prompt", "")
@@ -570,17 +638,43 @@ async def create_agent_config(
570638
return agent_config
571639

572640

573-
async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int = 0):
574-
# create tool
641+
async def create_tool_config_list(
642+
agent_id,
643+
tenant_id,
644+
user_id,
645+
version_no: int = 0,
646+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
647+
):
575648
tool_config_list = []
576649
langchain_tools = await discover_langchain_tools()
650+
normalized_tool_params = _normalize_tool_params_request(tool_params)
577651

578652
# now only admin can modify the agent, user_id is not used
579653
tools_list = search_tools_for_sub_agent(agent_id, tenant_id, version_no=version_no)
654+
655+
# Look up agent name for use in error messages.
656+
# Agent name is optional for tool_params matching (matching uses tool identifiers only),
657+
# but we include it in error messages so callers can identify which agent/tool caused a failure.
658+
agent_info = search_agent_info_by_agent_id(agent_id=agent_id, tenant_id=tenant_id, version_no=version_no)
659+
agent_name = agent_info.get("name") if agent_info else None
660+
agent_tool_overrides = _get_agent_tool_overrides(normalized_tool_params, agent_name)
661+
662+
tool_keys_seen = set()
580663
for tool in tools_list:
581-
param_dict = {}
582-
for param in tool.get("params", []):
583-
param_dict[param["name"]] = param.get("default")
664+
tool_identifier = tool.get("name") or tool.get("class_name")
665+
if tool_identifier in tool_keys_seen:
666+
raise ValidationError(
667+
f"Duplicate tool identifier '{tool_identifier}' found in agent '{agent_name or agent_id}'."
668+
)
669+
tool_keys_seen.add(tool_identifier)
670+
671+
override_params = None
672+
if tool.get("name") in agent_tool_overrides:
673+
override_params = agent_tool_overrides[tool.get("name")]
674+
elif tool.get("class_name") in agent_tool_overrides:
675+
override_params = agent_tool_overrides[tool.get("class_name")]
676+
677+
param_dict = _merge_tool_params(tool, override_params)
584678
tool_config = ToolConfig(
585679
class_name=tool.get("class_name"),
586680
name=tool.get("name"),
@@ -599,20 +693,29 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
599693
tool_config.metadata = langchain_tool
600694
break
601695

696+
# Extract document_paths for KnowledgeBaseSearchTool (internal access control, not in DB schema)
697+
document_paths = None
698+
if override_params and "document_paths" in override_params:
699+
document_paths = override_params.get("document_paths")
700+
# Also check using the tool name as key
701+
if not document_paths:
702+
kb_overrides = agent_tool_overrides.get("knowledge_base_search")
703+
if kb_overrides and "document_paths" in kb_overrides:
704+
document_paths = kb_overrides.get("document_paths")
705+
602706
# special logic for search tools that may use reranking models
603707
if tool_config.class_name == "KnowledgeBaseSearchTool":
604-
rerank = param_dict.get("rerank", False)
605-
rerank_model_name = param_dict.get("rerank_model_name", "")
708+
rerank = tool_config.params.get("rerank", False)
709+
rerank_model_name = tool_config.params.get("rerank_model_name", "")
606710
rerank_model = None
607-
is_multimodal = bool(tool_config.params.pop("multimodal", False))
608711
if rerank and rerank_model_name:
609712
rerank_model = get_rerank_model(
610713
tenant_id=tenant_id, model_name=rerank_model_name
611714
)
612715

613716
# Build display_name to index_name mapping for LLM parameter conversion
614717
# Also build reverse mapping (index_name -> display_name) for knowledge_base_summary
615-
index_names = param_dict.get("index_names", [])
718+
index_names = tool_config.params.get("index_names", [])
616719
display_name_to_index_map = {}
617720
index_name_to_display_map = {}
618721
if index_names:
@@ -628,12 +731,14 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
628731
"rerank_model": rerank_model,
629732
"display_name_to_index_map": display_name_to_index_map,
630733
"index_name_to_display_map": index_name_to_display_map,
734+
# Internal access control: restrict results to specific document paths (path_or_urls)
735+
"document_paths": document_paths,
631736
}
632737

633-
# Must have embedding model for knowledge base search
634738
if not index_names:
635739
raise ValidationError(
636-
"Embedding model is required for knowledge_base_search but index_names is empty")
740+
f"[{agent_name or agent_id}] knowledge_base_search tool requires index_names, "
741+
f"but it is not configured in the agent and not provided via tool_params.")
637742

638743
embedding_model, _, _ = get_embedding_model_by_index_name(tenant_id, index_names[0])
639744
if not embedding_model:
@@ -642,8 +747,8 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int
642747
f"Please configure an embedding model for this knowledge base.")
643748
tool_config.metadata["embedding_model"] = embedding_model
644749
elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]:
645-
rerank = param_dict.get("rerank", False)
646-
rerank_model_name = param_dict.get("rerank_model_name", "")
750+
rerank = tool_config.params.get("rerank", False)
751+
rerank_model_name = tool_config.params.get("rerank_model_name", "")
647752
rerank_model = None
648753
if rerank and rerank_model_name:
649754
rerank_model = get_rerank_model(
@@ -937,6 +1042,7 @@ async def create_agent_run_info(
9371042
is_debug: bool = False,
9381043
override_version_no: int | None = None,
9391044
override_model_id: int | None = None,
1045+
tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None,
9401046
):
9411047
# Determine which version_no to use based on is_debug flag
9421048
# If is_debug=false, use the current published version (current_version_no)
@@ -969,7 +1075,7 @@ async def create_agent_run_info(
9691075
if override_model_id is not None:
9701076
create_config_kwargs["override_model_id"] = override_model_id
9711077

972-
agent_config = await create_agent_config(**create_config_kwargs)
1078+
agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params)
9731079

9741080
remote_mcp_list = await get_remote_mcp_server_list(tenant_id=tenant_id, is_need_auth=True)
9751081
default_mcp_url = urljoin(LOCAL_MCP_SERVER, "sse")

0 commit comments

Comments
 (0)