Skip to content

Commit 4bf4e61

Browse files
authored
fix: [Simple Agent] In simple agents, workflow work is used as a skill, and no specified skill is called during dialogue (#5043)
1 parent aa3d9e1 commit 4bf4e61

File tree

3 files changed

+120
-117
lines changed

3 files changed

+120
-117
lines changed

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
2323
from application.chat_pipeline.pipeline_manage import PipelineManage
2424
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
25-
from application.flow.tools import Reasoning, mcp_response_generator
25+
from application.flow.tools import Reasoning, mcp_response_generator, get_tools
2626
from application.models import ApplicationChatUserStats, ChatUserType, Application, ApplicationApiKey, \
2727
ApplicationAccessToken
2828
from common.exception.app_exception import AppApiException
@@ -31,7 +31,7 @@
3131
from common.utils.shared_resource_auth import filter_authorized_ids
3232
from common.utils.tool_code import ToolExecutor
3333
from models_provider.tools import get_model_instance_by_model_workspace_id
34-
from tools.models import Tool
34+
from tools.models import Tool, ToolType
3535

3636

3737
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -232,7 +232,7 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
232232

233233
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
234234
application_ids, skill_tool_ids, mcp_output_enable, chat_model, message_list, agent_id,
235-
chat_id):
235+
chat_id, workspace_id):
236236

237237
mcp_servers_config = {}
238238

@@ -252,10 +252,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
252252
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
253253

254254
tool_init_params = {}
255+
tools = get_tools("APPLICATION", agent_id, tool_ids,
256+
workspace_id)
255257
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
256258
self.context['tool_ids'] = tool_ids
257259
for tool_id in tool_ids:
258-
tool = QuerySet(Tool).filter(id=tool_id).first()
260+
tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first()
259261
if tool is None or tool.is_active is False:
260262
continue
261263
executor = ToolExecutor()
@@ -316,12 +318,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
316318
})
317319
mcp_servers_config['skills'] = skill_file_items
318320

319-
if len(mcp_servers_config) > 0:
321+
if len(mcp_servers_config) > 0 or len(tools) > 0:
320322
source_id = agent_id
321323
source_type = 'APPLICATION'
322324
return mcp_response_generator(
323325
chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
324-
tool_init_params, source_id, source_type, chat_id
326+
tool_init_params, source_id, source_type, chat_id, tools
325327
)
326328

327329
return None
@@ -372,7 +374,7 @@ def get_stream_result(self, message_list: List[BaseMessage],
372374
mcp_result = self._handle_mcp_request(
373375
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
374376
application_ids, skill_tool_ids, mcp_output_enable, chat_model,
375-
message_list, agent_id, chat_id
377+
message_list, agent_id, chat_id, workspace_id
376378
)
377379
if mcp_result:
378380
return mcp_result, True
@@ -461,7 +463,7 @@ def get_block_result(self, message_list: List[BaseMessage],
461463
mcp_result = self._handle_mcp_request(
462464
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
463465
application_ids, skill_tool_ids, mcp_output_enable,
464-
chat_model, message_list, application_id, chat_id
466+
chat_model, message_list, application_id, chat_id, workspace_id
465467
)
466468
if mcp_result:
467469
return mcp_result, True
@@ -496,7 +498,7 @@ def execute_block(self, message_list: List[BaseMessage],
496498
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
497499
no_references_setting, problem_text,
498500
mcp_tool_ids, mcp_servers, mcp_source,
499-
tool_ids, application_ids, skill_tool_ids,workspace_id,
501+
tool_ids, application_ids, skill_tool_ids, workspace_id,
500502
mcp_output_enable, manage.context.get('application_id'),
501503
chat_id)
502504
if is_ai_chat:

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 7 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -9,123 +9,25 @@
99
import json
1010
import re
1111
import time
12-
import uuid
1312
from functools import reduce
1413
from typing import List, Dict
1514

16-
import uuid_utils.compat as uuid
17-
from django.db.models import QuerySet, OuterRef, Subquery
15+
from django.db.models import QuerySet
1816
from django.utils.translation import gettext as _
1917
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
20-
from langchain_core.tools import StructuredTool
21-
from pydantic import Field, create_model
2218

23-
from application.flow.common import Workflow, WorkflowMode
24-
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler
19+
from application.flow.common import WorkflowMode
20+
from application.flow.i_step_node import NodeResult, INode
2521
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
26-
from application.flow.tools import Reasoning, mcp_response_generator
22+
from application.flow.tools import Reasoning, mcp_response_generator, get_tools
2723
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
28-
from application.serializers.common import ToolExecute
2924
from common.exception.app_exception import AppApiException
3025
from common.utils.rsa_util import rsa_long_decrypt
3126
from common.utils.shared_resource_auth import filter_authorized_ids
3227
from common.utils.tool_code import ToolExecutor
3328
from models_provider.models import Model
3429
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
35-
from tools.models import Tool, ToolWorkflowVersion, ToolType
36-
37-
38-
def build_schema(fields: dict):
39-
return create_model("dynamicSchema", **fields)
40-
41-
42-
def get_type(_type: str):
43-
if _type == 'float':
44-
return float
45-
if _type == 'string':
46-
return str
47-
if _type == 'int':
48-
return int
49-
if _type == 'dict':
50-
return dict
51-
if _type == 'array':
52-
return list
53-
if _type == 'boolean':
54-
return bool
55-
return object
56-
57-
58-
def get_workflow_args(tool, qv):
59-
for node in qv.work_flow.get('nodes'):
60-
if node.get('type') == 'tool-base-node':
61-
input_field_list = node.get('properties').get('user_input_field_list')
62-
return build_schema(
63-
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
64-
for field in input_field_list})
65-
66-
return build_schema({})
67-
68-
69-
def get_workflow_func(node, tool, qv, workspace_id):
70-
tool_id = tool.id
71-
tool_record_id = str(uuid.uuid7())
72-
took_execute = ToolExecute(tool_id, tool_record_id,
73-
workspace_id,
74-
node.workflow_manage.get_source_type(),
75-
node.workflow_manage.get_source_id(),
76-
False)
77-
78-
def inner(**kwargs):
79-
from application.flow.tool_workflow_manage import ToolWorkflowManage
80-
work_flow_manage = ToolWorkflowManage(
81-
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
82-
{
83-
'chat_record_id': tool_record_id,
84-
'tool_id': tool_id,
85-
'stream': True,
86-
'workspace_id': workspace_id,
87-
**kwargs},
88-
89-
ToolWorkflowPostHandler(took_execute, tool_id),
90-
is_the_task_interrupted=lambda: False,
91-
child_node=None,
92-
start_node_id=None,
93-
start_node_data=None,
94-
chat_record=None
95-
)
96-
res = work_flow_manage.run()
97-
for r in res:
98-
pass
99-
return work_flow_manage.out_context
100-
101-
return inner
102-
103-
104-
def get_tools(node, tool_workflow_ids, workspace_id):
105-
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
106-
latest_subquery = ToolWorkflowVersion.objects.filter(
107-
tool_id=OuterRef('tool_id')
108-
).order_by('-create_time')
109-
110-
qs = ToolWorkflowVersion.objects.filter(
111-
tool_id__in=[t.id for t in tools],
112-
id=Subquery(latest_subquery.values('id')[:1])
113-
)
114-
qd = {q.tool_id: q for q in qs}
115-
results = []
116-
for tool in tools:
117-
qv = qd.get(tool.id)
118-
func = get_workflow_func(node, tool, qv, workspace_id)
119-
args = get_workflow_args(tool, qv)
120-
tool = StructuredTool.from_function(
121-
func=func,
122-
name=tool.name,
123-
description=tool.desc,
124-
args_schema=args,
125-
)
126-
results.append(tool)
127-
128-
return results
30+
from tools.models import Tool, ToolType
12931

13032

13133
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
@@ -362,7 +264,8 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
362264
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
363265
mcp_servers_config = self.handle_variables(mcp_servers_config)
364266
tool_init_params = {}
365-
tools = get_tools(self, tool_ids, workspace_id)
267+
tools = get_tools(self.workflow_manage.get_source_type(), self.workflow_manage.get_source_id(), tool_ids,
268+
workspace_id)
366269
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
367270
self.context['tool_ids'] = tool_ids
368271
for tool_id in tool_ids:

apps/application/flow/tools.py

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
@date:2024/6/6 15:15
77
@desc:
88
"""
9-
from tools.models import ToolRecord, Tool, ToolScope
9+
from langchain_core.tools import StructuredTool
10+
11+
from application.flow.common import Workflow, WorkflowMode
12+
from application.serializers.common import ToolExecute
13+
from tools.models import ToolRecord, Tool, ToolScope, ToolWorkflowVersion, ToolType
1014
from maxkb.const import CONFIG
1115
from knowledge.models.knowledge_action import State
1216
from knowledge.models import File
1317
from common.utils.logger import maxkb_logger
1418
from common.result import result
15-
from application.flow.i_step_node import WorkFlowPostHandler
19+
from application.flow.i_step_node import WorkFlowPostHandler, ToolWorkflowPostHandler
1620
from application.flow.backend.sandbox_shell import SandboxShellBackend
1721
import asyncio
1822
import io
@@ -25,11 +29,11 @@
2529
import zipfile
2630
from functools import reduce
2731
from typing import Iterator
28-
32+
from pydantic import Field, create_model
2933
import uuid_utils.compat as uuid
3034
from asgiref.sync import sync_to_async
3135
from deepagents import create_deep_agent
32-
from django.db.models import QuerySet
36+
from django.db.models import QuerySet, OuterRef, Subquery
3337
from django.http import StreamingHttpResponse
3438
from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk, SystemMessage
3539
from langchain_mcp_adapters.client import MultiServerMCPClient
@@ -990,3 +994,97 @@ def get_child_tool_id_list(work_flow, response):
990994
for tool in tool_list:
991995
response.append(str(tool.id))
992996
return response
997+
998+
999+
def build_schema(fields: dict):
1000+
return create_model("dynamicSchema", **fields)
1001+
1002+
1003+
def get_type(_type: str):
1004+
if _type == 'float':
1005+
return float
1006+
if _type == 'string':
1007+
return str
1008+
if _type == 'int':
1009+
return int
1010+
if _type == 'dict':
1011+
return dict
1012+
if _type == 'array':
1013+
return list
1014+
if _type == 'boolean':
1015+
return bool
1016+
return object
1017+
1018+
1019+
def get_workflow_args(tool, qv):
1020+
for node in qv.work_flow.get('nodes'):
1021+
if node.get('type') == 'tool-base-node':
1022+
input_field_list = node.get('properties').get('user_input_field_list')
1023+
return build_schema(
1024+
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
1025+
for field in input_field_list})
1026+
1027+
return build_schema({})
1028+
1029+
1030+
def get_workflow_func(source_type, source_id, tool, qv, workspace_id):
1031+
tool_id = tool.id
1032+
tool_record_id = str(uuid.uuid7())
1033+
took_execute = ToolExecute(tool_id, tool_record_id,
1034+
workspace_id,
1035+
source_type,
1036+
source_id,
1037+
False)
1038+
1039+
def inner(**kwargs):
1040+
from application.flow.tool_workflow_manage import ToolWorkflowManage
1041+
work_flow_manage = ToolWorkflowManage(
1042+
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
1043+
{
1044+
'chat_record_id': tool_record_id,
1045+
'tool_id': tool_id,
1046+
'stream': True,
1047+
'workspace_id': workspace_id,
1048+
**kwargs},
1049+
1050+
ToolWorkflowPostHandler(took_execute, tool_id),
1051+
is_the_task_interrupted=lambda: False,
1052+
child_node=None,
1053+
start_node_id=None,
1054+
start_node_data=None,
1055+
chat_record=None
1056+
)
1057+
res = work_flow_manage.run()
1058+
for r in res:
1059+
pass
1060+
return work_flow_manage.out_context
1061+
1062+
return inner
1063+
1064+
1065+
def get_tools(source_type, source_id, tool_workflow_ids, workspace_id):
1066+
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
1067+
latest_subquery = ToolWorkflowVersion.objects.filter(
1068+
tool_id=OuterRef('tool_id')
1069+
).order_by('-create_time')
1070+
1071+
qs = ToolWorkflowVersion.objects.filter(
1072+
tool_id__in=[t.id for t in tools],
1073+
id=Subquery(latest_subquery.values('id')[:1])
1074+
)
1075+
qd = {q.tool_id: q for q in qs}
1076+
results = []
1077+
for tool in tools:
1078+
qv = qd.get(tool.id)
1079+
func = get_workflow_func(source_type, source_id, tool, qv,
1080+
workspace_id)
1081+
args = get_workflow_args(tool, qv)
1082+
tool = StructuredTool.from_function(
1083+
func=func,
1084+
name=tool.name,
1085+
description=tool.desc,
1086+
args_schema=args,
1087+
)
1088+
results.append(tool)
1089+
1090+
return results

0 commit comments

Comments
 (0)