Skip to content

Commit 5e8b5b4

Browse files
committed
feat: AI node supports workflow calling tools
1 parent d82fa99 commit 5e8b5b4

File tree

10 files changed

+169
-24
lines changed

10 files changed

+169
-24
lines changed

apps/application/flow/i_step_node.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def get_tool_workflow_state(workflow):
128128
return State.SUCCESS
129129

130130

131+
class ToolWorkflowCallPostHandler(WorkFlowPostHandler):
132+
def __init__(self, chat_info, tool_id):
133+
super().__init__(chat_info)
134+
self.tool_id = tool_id
135+
136+
def handler(self, workflow):
137+
self.chat_info = None
138+
self.tool_id = None
139+
140+
131141
class ToolWorkflowPostHandler(WorkFlowPostHandler):
132142
def __init__(self, chat_info, tool_id):
133143
super().__init__(chat_info)

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,123 @@
99
import json
1010
import re
1111
import time
12+
import uuid
1213
from functools import reduce
1314
from typing import List, Dict
1415

15-
from application.flow.i_step_node import NodeResult, INode
16+
from langchain_core.tools import StructuredTool
17+
18+
from application.flow.common import Workflow, WorkflowMode
19+
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler, ToolWorkflowCallPostHandler
1620
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
1721
from application.flow.tools import Reasoning, mcp_response_generator
1822
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
23+
from application.serializers.common import ToolExecute
1924
from common.exception.app_exception import AppApiException
2025
from common.utils.rsa_util import rsa_long_decrypt
2126
from common.utils.shared_resource_auth import filter_authorized_ids
2227
from common.utils.tool_code import ToolExecutor
23-
from django.db.models import QuerySet
28+
from django.db.models import QuerySet, OuterRef, Subquery
2429
from django.utils.translation import gettext as _
2530
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
2631
from models_provider.models import Model
2732
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
28-
from tools.models import Tool
33+
from tools.models import Tool, ToolWorkflowVersion, ToolType
34+
from pydantic import BaseModel, Field, create_model
35+
import uuid_utils.compat as uuid
36+
37+
38+
def build_schema(fields: dict):
39+
return create_model("dynamicSchema", **fields)
40+
41+
42+
def get_type(_type: str):
43+
if _type == 'float':
44+
return float
45+
if _type == 'string':
46+
return str
47+
if _type == 'int':
48+
return int
49+
if _type == 'dict':
50+
return dict
51+
if _type == 'array':
52+
return list
53+
if _type == 'boolean':
54+
return bool
55+
return object
56+
57+
58+
def get_workflow_args(tool, qv):
59+
for node in qv.work_flow.get('nodes'):
60+
if node.get('type') == 'tool-base-node':
61+
input_field_list = node.get('properties').get('user_input_field_list')
62+
return build_schema(
63+
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
64+
for field in input_field_list})
65+
66+
return build_schema({})
67+
68+
69+
def get_workflow_func(tool, qv, workspace_id):
70+
tool_id = tool.id
71+
tool_record_id = str(uuid.uuid7())
72+
took_execute = ToolExecute(tool_id, tool_record_id,
73+
workspace_id,
74+
None,
75+
None,
76+
True)
77+
78+
def inner(**kwargs):
79+
from application.flow.tool_workflow_manage import ToolWorkflowManage
80+
work_flow_manage = ToolWorkflowManage(
81+
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
82+
{
83+
'chat_record_id': tool_record_id,
84+
'tool_id': tool_id,
85+
'stream': True,
86+
'workspace_id': workspace_id,
87+
**kwargs},
88+
89+
ToolWorkflowCallPostHandler(took_execute, tool_id),
90+
is_the_task_interrupted=lambda: False,
91+
child_node=None,
92+
start_node_id=None,
93+
start_node_data=None,
94+
chat_record=None
95+
)
96+
res = work_flow_manage.run()
97+
for r in res:
98+
pass
99+
return work_flow_manage.out_context
100+
101+
return inner
102+
103+
104+
def get_tools(tool_workflow_ids, workspace_id):
105+
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
106+
latest_subquery = ToolWorkflowVersion.objects.filter(
107+
tool_id=OuterRef('tool_id')
108+
).order_by('-create_time')
109+
110+
qs = ToolWorkflowVersion.objects.filter(
111+
tool_id__in=[t.id for t in tools],
112+
id=Subquery(latest_subquery.values('id')[:1])
113+
)
114+
qd = {q.tool_id: q for q in qs}
115+
results = []
116+
for tool in tools:
117+
qv = qd.get(tool.id)
118+
func = get_workflow_func(tool, qv, workspace_id)
119+
args = get_workflow_args(tool, qv)
120+
tool = StructuredTool.from_function(
121+
func=func,
122+
name=tool.name,
123+
description=tool.desc,
124+
args_schema=args,
125+
)
126+
results.append(tool)
127+
128+
return results
29129

30130

31131
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
@@ -178,7 +278,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
178278
model_id = reference_data.get('model_id', model_id)
179279
model_params_setting = reference_data.get('model_params_setting')
180280

181-
if model_params_setting is None and model_id:
281+
if model_params_setting is None and model_id:
182282
model_params_setting = get_default_model_params_setting(model_id)
183283

184284
if model_setting is None:
@@ -216,7 +316,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
216316
mcp_result = self._handle_mcp_request(
217317
mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
218318
application_ids, skill_tool_ids, mcp_output_enable,
219-
chat_model, message_list, history_message, question, chat_id
319+
chat_model, message_list, history_message, question, chat_id, workspace_id
220320
)
221321
if mcp_result:
222322
return mcp_result
@@ -236,7 +336,8 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
236336

237337
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
238338
application_ids, skill_tool_ids,
239-
mcp_output_enable, chat_model, message_list, history_message, question, chat_id):
339+
mcp_output_enable, chat_model, message_list, history_message, question, chat_id,
340+
workspace_id):
240341

241342
mcp_servers_config = {}
242343

@@ -259,11 +360,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
259360
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
260361
mcp_servers_config = self.handle_variables(mcp_servers_config)
261362
tool_init_params = {}
363+
tools = get_tools(tool_ids, workspace_id)
262364
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
263365
self.context['tool_ids'] = tool_ids
264366
for tool_id in tool_ids:
265-
tool = QuerySet(Tool).filter(id=tool_id).first()
266-
if not tool.is_active:
367+
tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first()
368+
if tool is None or not tool.is_active:
267369
continue
268370
executor = ToolExecutor()
269371
if tool.init_params is not None:
@@ -323,7 +425,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
323425
})
324426
mcp_servers_config['skills'] = skill_file_items
325427

326-
if len(mcp_servers_config) > 0:
428+
if len(mcp_servers_config) > 0 or len(tools) > 0:
327429
# 安全获取 application
328430
application_id = None
329431
if (self.workflow_manage and
@@ -334,7 +436,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
334436
source_id = application_id or knowledge_id
335437
source_type = 'APPLICATION' if application_id else 'KNOWLEDGE'
336438
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
337-
tool_init_params, source_id, source_type, chat_id)
439+
tool_init_params, source_id, source_type, chat_id, tools)
338440
return NodeResult(
339441
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
340442
'history_message': [{'content': message.content, 'role': message.type} for message in

apps/application/flow/tools.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def _merge_lists_normalize_empty_tool_chunk_ids(left, *others):
5656
"""Wrapper around merge_lists that normalises empty-string IDs to None in
5757
tool_call_chunk items (those with an 'index' key) so that qwen streaming
5858
chunks with id='' are merged correctly by index."""
59+
5960
def _norm(lst):
6061
if lst is None:
6162
return lst
@@ -158,17 +159,17 @@ def get_reasoning_content(self, chunk):
158159
self.reasoning_content_end_tag)
159160
if reasoning_content_end_tag_index > -1:
160161
reasoning_content_chunk = self.reasoning_content_chunk[
161-
0:reasoning_content_end_tag_index]
162+
0:reasoning_content_end_tag_index]
162163
content_chunk = self.reasoning_content_chunk[
163-
reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
164+
reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
164165
self.reasoning_content += reasoning_content_chunk
165166
self.content += content_chunk
166167
self.reasoning_content_chunk = ""
167168
self.reasoning_content_is_end = True
168169
return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
169170
else:
170171
reasoning_content_chunk = self.reasoning_content_chunk[
171-
0:reasoning_content_end_tag_prefix_index + 1]
172+
0:reasoning_content_end_tag_prefix_index + 1]
172173
self.reasoning_content_chunk = self.reasoning_content_chunk.replace(
173174
reasoning_content_chunk, '')
174175
self.reasoning_content += reasoning_content_chunk
@@ -401,11 +402,14 @@ async def _initialize_skills(mcp_servers, temp_dir):
401402

402403

403404
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
404-
source_id=None, source_type=None, temp_dir=None, chat_id=None):
405+
source_id=None, source_type=None, temp_dir=None, chat_id=None, extra_tools=None):
405406
try:
406407
checkpointer = MemorySaver()
407408
client = await _initialize_skills(mcp_servers, temp_dir)
408409
tools = await client.get_tools()
410+
if extra_tools:
411+
for tool in extra_tools:
412+
tools.append(tool)
409413
agent = create_deep_agent(
410414
model=chat_model,
411415
backend=SandboxShellBackend(root_dir=temp_dir, virtual_mode=True),
@@ -517,7 +521,7 @@ def _upsert_fragment(key, raw_id, func_name, part_args):
517521
# qwen-plus often emits {} here as a placeholder while
518522
# the real args are split in tool_call_chunks/invalid_tool_calls.
519523
if has_tool_call_chunks and (
520-
part_args == '' or part_args == {} or part_args == []
524+
part_args == '' or part_args == {} or part_args == []
521525
):
522526
part_args = ''
523527
key = _get_fragment_key(tool_call.get('index'), raw_id)
@@ -563,9 +567,9 @@ def _upsert_fragment(key, raw_id, func_name, part_args):
563567
# 3. 检测工具调用结束,更新 tool_calls_info
564568
# ----------------------------------------------------------------
565569
is_finish_chunk = (
566-
chunk[0].response_metadata.get(
567-
'finish_reason') == 'tool_calls'
568-
or chunk[0].chunk_position == 'last'
570+
chunk[0].response_metadata.get(
571+
'finish_reason') == 'tool_calls'
572+
or chunk[0].chunk_position == 'last'
569573
)
570574

571575
if is_finish_chunk:
@@ -734,7 +738,7 @@ async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_ty
734738

735739

736740
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
737-
source_id=None, source_type=None, chat_id=None):
741+
source_id=None, source_type=None, chat_id=None, extra_tools=None):
738742
"""使用全局事件循环,不创建新实例"""
739743
result_queue = queue.Queue()
740744
loop = get_global_loop() # 使用共享循环
@@ -751,7 +755,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_ena
751755
async def _run():
752756
try:
753757
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params,
754-
source_id, source_type, temp_dir, chat_id)
758+
source_id, source_type, temp_dir, chat_id, extra_tools)
755759
async for chunk in async_gen:
756760
result_queue.put(('data', chunk))
757761
except Exception as e:

ui/src/locales/lang/en-US/dynamics-form.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ export default {
5151
placeholder: 'Please select a type',
5252
requiredMessage: 'Type is a required property',
5353
},
54+
desc: {
55+
label: 'description',
56+
placeholder: 'Please enter a description',
57+
},
5458
},
5559
DatePicker: {
5660
placeholder: 'Select Date',

ui/src/locales/lang/zh-CN/dynamics-form.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ export default {
5151
placeholder: '请选择组件类型',
5252
requiredMessage: '组建类型 为必填属性',
5353
},
54+
desc: {
55+
label: '描述',
56+
placeholder: '请输入描述',
57+
},
5458
},
5559
DatePicker: {
5660
placeholder: '选择日期',

ui/src/locales/lang/zh-Hant/dynamics-form.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ export default {
5151
placeholder: '請選擇組件類型',
5252
requiredMessage: '組件類型 為必填屬性',
5353
},
54+
desc: {
55+
label: '描述',
56+
placeholder: '請輸入描述',
57+
},
5458
},
5559
DatePicker: {
5660
placeholder: '選擇日期',

ui/src/views/application/component/ToolDialog.vue

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,20 @@ function getFolder() {
279279
280280
function getList() {
281281
const folder_id = currentFolder.value?.id || user.getWorkspaceId()
282+
const query: any = {}
283+
if (props.tool_type.includes(',')) {
284+
query['tool_type_list'] = props.tool_type.split(',')
285+
} else {
286+
query['tool_type'] = props.tool_type
287+
}
282288
loadSharedApi({
283289
type: 'tool',
284290
isShared: folder_id === 'share',
285291
systemType: apiType.value,
286292
})
287293
.getToolList({
288294
folder_id: folder_id,
289-
tool_type: props.tool_type,
295+
...query,
290296
})
291297
.then((res: any) => {
292298
toolList.value = res.data?.tools || res.data || []

ui/src/workflow/nodes/ai-chat-node/index.vue

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@
488488
@refresh="submitReasoningDialog"
489489
/>
490490
<McpServersDialog ref="mcpServersDialogRef" @refresh="submitMcpServersDialog" />
491-
<ToolDialog ref="toolDialogRef" @refresh="submitToolDialog" tool_type="CUSTOM" />
491+
<ToolDialog ref="toolDialogRef" @refresh="submitToolDialog" tool_type="CUSTOM,WORKFLOW" />
492492
<ToolDialog ref="skillToolDialogRef" @refresh="submitSkillToolDialog" tool_type="SKILL" />
493493
<ApplicationDialog ref="applicationDialogRef" @refresh="submitApplicationDialog" />
494494
</NodeContainer>
@@ -724,12 +724,12 @@ function getToolSelectOptions() {
724724
apiType.value === 'systemManage'
725725
? {
726726
scope: 'WORKSPACE',
727-
tool_type: 'CUSTOM',
727+
tool_type_list: ['CUSTOM', 'WORKFLOW'],
728728
workspace_id: resource.value?.workspace_id,
729729
}
730730
: {
731731
scope: 'WORKSPACE',
732-
tool_type: 'CUSTOM',
732+
tool_type_list: ['CUSTOM', 'WORKFLOW'],
733733
}
734734
735735
loadSharedApi({ type: 'tool', systemType: apiType.value })

ui/src/workflow/nodes/tool-base-node/component/input/InputFieldFormDialog.vue

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@
3232
@blur="form.label = form.label?.trim()"
3333
/>
3434
</el-form-item>
35+
<el-form-item :label="$t('dynamicsForm.paramForm.desc.label')">
36+
<el-input
37+
v-model="form.desc"
38+
:placeholder="$t('dynamicsForm.paramForm.desc.placeholder')"
39+
:maxlength="128"
40+
show-word-limit
41+
@blur="form.label = form.desc?.trim()"
42+
/>
43+
</el-form-item>
3544
<el-form-item :label="$t('views.tool.form.dataType.label')">
3645
<el-select v-model="form.type">
3746
<el-option v-for="item in typeOptions" :key="item" :label="item" :value="item" />
@@ -66,6 +75,7 @@ const form = ref<any>({
6675
field: '',
6776
type: typeOptions[0],
6877
label: '',
78+
desc: '',
6979
is_required: true,
7080
})
7181

0 commit comments

Comments
 (0)