Skip to content

Commit 078bda5

Browse files
Merge remote-tracking branch 'upstream/v2' into perf
2 parents b3e7b65 + 8704b22 commit 078bda5

25 files changed

Lines changed: 470 additions & 119 deletions

File tree

apps/application/flow/i_step_node.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def get_tool_workflow_state(workflow):
128128
return State.SUCCESS
129129

130130

131+
class ToolWorkflowCallPostHandler(WorkFlowPostHandler):
132+
def __init__(self, chat_info, tool_id):
133+
super().__init__(chat_info)
134+
self.tool_id = tool_id
135+
136+
def handler(self, workflow):
137+
self.chat_info = None
138+
self.tool_id = None
139+
140+
131141
class ToolWorkflowPostHandler(WorkFlowPostHandler):
132142
def __init__(self, chat_info, tool_id):
133143
super().__init__(chat_info)

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 113 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,123 @@
99
import json
1010
import re
1111
import time
12+
import uuid
1213
from functools import reduce
1314
from typing import List, Dict
1415

15-
from application.flow.i_step_node import NodeResult, INode
16+
from langchain_core.tools import StructuredTool
17+
18+
from application.flow.common import Workflow, WorkflowMode
19+
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler, ToolWorkflowCallPostHandler
1620
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
1721
from application.flow.tools import Reasoning, mcp_response_generator
1822
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
23+
from application.serializers.common import ToolExecute
1924
from common.exception.app_exception import AppApiException
2025
from common.utils.rsa_util import rsa_long_decrypt
2126
from common.utils.shared_resource_auth import filter_authorized_ids
2227
from common.utils.tool_code import ToolExecutor
23-
from django.db.models import QuerySet
28+
from django.db.models import QuerySet, OuterRef, Subquery
2429
from django.utils.translation import gettext as _
2530
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
2631
from models_provider.models import Model
2732
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
28-
from tools.models import Tool
33+
from tools.models import Tool, ToolWorkflowVersion, ToolType
34+
from pydantic import BaseModel, Field, create_model
35+
import uuid_utils.compat as uuid
36+
37+
38+
def build_schema(fields: dict):
39+
return create_model("dynamicSchema", **fields)
40+
41+
42+
def get_type(_type: str):
43+
if _type == 'float':
44+
return float
45+
if _type == 'string':
46+
return str
47+
if _type == 'int':
48+
return int
49+
if _type == 'dict':
50+
return dict
51+
if _type == 'array':
52+
return list
53+
if _type == 'boolean':
54+
return bool
55+
return object
56+
57+
58+
def get_workflow_args(tool, qv):
59+
for node in qv.work_flow.get('nodes'):
60+
if node.get('type') == 'tool-base-node':
61+
input_field_list = node.get('properties').get('user_input_field_list')
62+
return build_schema(
63+
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
64+
for field in input_field_list})
65+
66+
return build_schema({})
67+
68+
69+
def get_workflow_func(tool, qv, workspace_id):
70+
tool_id = tool.id
71+
tool_record_id = str(uuid.uuid7())
72+
took_execute = ToolExecute(tool_id, tool_record_id,
73+
workspace_id,
74+
None,
75+
None,
76+
True)
77+
78+
def inner(**kwargs):
79+
from application.flow.tool_workflow_manage import ToolWorkflowManage
80+
work_flow_manage = ToolWorkflowManage(
81+
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
82+
{
83+
'chat_record_id': tool_record_id,
84+
'tool_id': tool_id,
85+
'stream': True,
86+
'workspace_id': workspace_id,
87+
**kwargs},
88+
89+
ToolWorkflowCallPostHandler(took_execute, tool_id),
90+
is_the_task_interrupted=lambda: False,
91+
child_node=None,
92+
start_node_id=None,
93+
start_node_data=None,
94+
chat_record=None
95+
)
96+
res = work_flow_manage.run()
97+
for r in res:
98+
pass
99+
return work_flow_manage.out_context
100+
101+
return inner
102+
103+
104+
def get_tools(tool_workflow_ids, workspace_id):
105+
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
106+
latest_subquery = ToolWorkflowVersion.objects.filter(
107+
tool_id=OuterRef('tool_id')
108+
).order_by('-create_time')
109+
110+
qs = ToolWorkflowVersion.objects.filter(
111+
tool_id__in=[t.id for t in tools],
112+
id=Subquery(latest_subquery.values('id')[:1])
113+
)
114+
qd = {q.tool_id: q for q in qs}
115+
results = []
116+
for tool in tools:
117+
qv = qd.get(tool.id)
118+
func = get_workflow_func(tool, qv, workspace_id)
119+
args = get_workflow_args(tool, qv)
120+
tool = StructuredTool.from_function(
121+
func=func,
122+
name=tool.name,
123+
description=tool.desc,
124+
args_schema=args,
125+
)
126+
results.append(tool)
127+
128+
return results
29129

30130

31131
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
@@ -178,7 +278,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
178278
model_id = reference_data.get('model_id', model_id)
179279
model_params_setting = reference_data.get('model_params_setting')
180280

181-
if model_params_setting is None and model_id:
281+
if model_params_setting is None and model_id:
182282
model_params_setting = get_default_model_params_setting(model_id)
183283

184284
if model_setting is None:
@@ -187,7 +287,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
187287
self.context['model_setting'] = model_setting
188288
workspace_id = self.workflow_manage.get_body().get('workspace_id')
189289
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
190-
**model_params_setting)
290+
**(model_params_setting or {}))
191291
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
192292
self.runtime_node_id)
193293
self.context['history_message'] = [{'content': message.content, 'role': message.type} for message in
@@ -216,7 +316,7 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
216316
mcp_result = self._handle_mcp_request(
217317
mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
218318
application_ids, skill_tool_ids, mcp_output_enable,
219-
chat_model, message_list, history_message, question, chat_id
319+
chat_model, message_list, history_message, question, chat_id, workspace_id
220320
)
221321
if mcp_result:
222322
return mcp_result
@@ -236,7 +336,8 @@ def execute(self, model_id, system, prompt, dialogue_number, history_chat_record
236336

237337
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids, tool_ids,
238338
application_ids, skill_tool_ids,
239-
mcp_output_enable, chat_model, message_list, history_message, question, chat_id):
339+
mcp_output_enable, chat_model, message_list, history_message, question, chat_id,
340+
workspace_id):
240341

241342
mcp_servers_config = {}
242343

@@ -259,11 +360,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
259360
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
260361
mcp_servers_config = self.handle_variables(mcp_servers_config)
261362
tool_init_params = {}
363+
tools = get_tools(tool_ids, workspace_id)
262364
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
263365
self.context['tool_ids'] = tool_ids
264366
for tool_id in tool_ids:
265-
tool = QuerySet(Tool).filter(id=tool_id).first()
266-
if not tool.is_active:
367+
tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first()
368+
if tool is None or not tool.is_active:
267369
continue
268370
executor = ToolExecutor()
269371
if tool.init_params is not None:
@@ -323,7 +425,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
323425
})
324426
mcp_servers_config['skills'] = skill_file_items
325427

326-
if len(mcp_servers_config) > 0:
428+
if len(mcp_servers_config) > 0 or len(tools) > 0:
327429
# 安全获取 application
328430
application_id = None
329431
if (self.workflow_manage and
@@ -334,7 +436,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
334436
source_id = application_id or knowledge_id
335437
source_type = 'APPLICATION' if application_id else 'KNOWLEDGE'
336438
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
337-
tool_init_params, source_id, source_type, chat_id)
439+
tool_init_params, source_id, source_type, chat_id, tools)
338440
return NodeResult(
339441
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
340442
'history_message': [{'content': message.content, 'role': message.type} for message in

apps/application/flow/step_node/image_to_video_step_node/i_image_to_video_node.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010

1111

1212
class ImageToVideoNodeSerializer(serializers.Serializer):
13-
model_id = serializers.CharField(required=True, label=_("Model id"))
13+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
14+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
15+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
16+
label=_("Reference Field"))
1417

1518
prompt = serializers.CharField(required=True, label=_("Prompt word (positive)"))
1619

@@ -69,5 +72,6 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
6972
model_params_setting,
7073
chat_record_id,
7174
first_frame_url, last_frame_url,
75+
model_id_type=None, model_id_reference=None,
7276
**kwargs) -> NodeResult:
7377
pass

apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,21 @@ def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_t
2929
model_params_setting,
3030
chat_record_id,
3131
first_frame_url, last_frame_url=None,
32+
model_id_type=None, model_id_reference=None,
3233
**kwargs) -> NodeResult:
34+
# 处理引用类型
35+
if model_id_type == 'reference' and model_id_reference:
36+
reference_data = self.workflow_manage.get_reference_field(
37+
model_id_reference[0],
38+
model_id_reference[1:],
39+
)
40+
if reference_data and isinstance(reference_data, dict):
41+
model_id = reference_data.get('model_id', model_id)
42+
model_params_setting = reference_data.get('model_params_setting')
43+
3344
workspace_id = self.workflow_manage.get_body().get('workspace_id')
3445
ttv_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
35-
**model_params_setting)
46+
**(model_params_setting or {}))
3647
history_message = self.get_history_message(history_chat_record, dialogue_number)
3748
self.context['history_message'] = history_message
3849
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/intent_node/impl/base_intent_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def execute(self, model_id, dialogue_number, history_chat_record, user_input, br
7171
# 获取模型实例
7272
workspace_id = self.workflow_manage.get_body().get('workspace_id')
7373
chat_model = get_model_instance_by_model_workspace_id(
74-
model_id, workspace_id, **model_params_setting
74+
model_id, workspace_id, **(model_params_setting or {})
7575
)
7676

7777
# 获取历史对话

apps/application/flow/step_node/parameter_extraction_node/i_parameter_extraction_node.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class VariableSplittingNodeParamsSerializer(serializers.Serializer):
1919
model_params_setting = serializers.DictField(required=False,
2020
label=_("Model parameter settings"))
2121

22-
model_id = serializers.CharField(required=True, label=_("Model id"))
22+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
23+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
24+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
25+
label=_("Reference Field"))
2326

2427

2528
class IParameterExtractionNode(INode):
@@ -31,12 +34,25 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
3134
return VariableSplittingNodeParamsSerializer
3235

3336
def _run(self):
37+
model_id_type = self.node_params_serializer.data.get('model_id_type')
38+
model_id_reference = self.node_params_serializer.data.get('model_id_reference')
39+
model_id = self.node_params_serializer.data.get('model_id')
40+
model_params_setting = self.node_params_serializer.data.get('model_params_setting')
41+
# 处理引用类型
42+
if model_id_type == 'reference' and model_id_reference:
43+
reference_data = self.workflow_manage.get_reference_field(
44+
model_id_reference[0],
45+
model_id_reference[1:],
46+
)
47+
if reference_data and isinstance(reference_data, dict):
48+
model_id = reference_data.get('model_id', model_id)
49+
model_params_setting = reference_data.get('model_params_setting')
50+
3451
input_variable = self.workflow_manage.get_reference_field(
3552
self.node_params_serializer.data.get('input_variable')[0],
3653
self.node_params_serializer.data.get('input_variable')[1:])
3754
return self.execute(input_variable, self.node_params_serializer.data['variable_list'],
38-
self.node_params_serializer.data['model_params_setting'],
39-
self.node_params_serializer.data['model_id'])
55+
model_params_setting, model_id)
4056

4157
def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
4258
pass

apps/application/flow/step_node/parameter_extraction_node/impl/base_parameter_extraction_node.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def save_context(self, details, workflow_manage):
9494
def execute(self, input_variable, variable_list, model_params_setting, model_id, **kwargs) -> NodeResult:
9595
input_variable = str(input_variable)
9696
self.context['request'] = input_variable
97-
if model_params_setting is None:
97+
if model_params_setting is None and model_id:
9898
model_params_setting = get_default_model_params_setting(model_id)
9999
workspace_id = self.workflow_manage.get_body().get('workspace_id')
100100
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
101-
**model_params_setting)
101+
**(model_params_setting or {}))
102+
102103
content = generate_content(input_variable, variable_list)
103104
response = chat_model.invoke([HumanMessage(content=content)])
104105
result = json_loads(response.content, variable_list)

apps/application/flow/step_node/question_node/i_question_node.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616

1717

1818
class QuestionNodeSerializer(serializers.Serializer):
19-
model_id = serializers.CharField(required=True, label=_("Model id"))
19+
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("Model id"))
20+
model_id_type = serializers.CharField(required=False, default='custom', label=_("Model id type"))
21+
model_id_reference = serializers.ListField(required=False, child=serializers.CharField(), allow_empty=True,
22+
label=_("Reference Field"))
2023
system = serializers.CharField(required=False, allow_blank=True, allow_null=True,
2124
label=_("Role Setting"))
2225
prompt = serializers.CharField(required=True, label=_("Prompt word"))
@@ -42,6 +45,6 @@ def _run(self):
4245
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
4346

4447
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
45-
model_params_setting=None,
48+
model_params_setting=None, model_id_type=None, model_id_reference=None,
4649
**kwargs) -> NodeResult:
4750
pass

apps/application/flow/step_node/question_node/impl/base_question_node.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,23 @@ def save_context(self, details, workflow_manage):
8383
self.answer_text = details.get('answer')
8484

8585
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
86-
model_params_setting=None,
86+
model_params_setting=None, model_id_type=None, model_id_reference=None,
8787
**kwargs) -> NodeResult:
88-
if model_params_setting is None:
88+
# 处理引用类型
89+
if model_id_type == 'reference' and model_id_reference:
90+
reference_data = self.workflow_manage.get_reference_field(
91+
model_id_reference[0],
92+
model_id_reference[1:],
93+
)
94+
if reference_data and isinstance(reference_data, dict):
95+
model_id = reference_data.get('model_id', model_id)
96+
model_params_setting = reference_data.get('model_params_setting')
97+
98+
if model_params_setting is None and model_id:
8999
model_params_setting = get_default_model_params_setting(model_id)
90100
workspace_id = self.workflow_manage.get_body().get('workspace_id')
91101
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
92-
**model_params_setting)
102+
**(model_params_setting or {}))
93103
history_message = self.get_history_message(history_chat_record, dialogue_number)
94104
self.context['history_message'] = history_message
95105
question = self.generate_prompt_question(prompt)

apps/application/flow/step_node/text_to_speech_step_node/impl/base_text_to_speech_node.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def execute(self, tts_model_id,
6969
self.context['content'] = chunk
7070
workspace_id = self.workflow_manage.get_body().get('workspace_id')
7171
model = get_model_instance_by_model_workspace_id(
72-
tts_model_id, workspace_id, **model_params_setting)
72+
tts_model_id, workspace_id, **(model_params_setting or {}))
7373

7474
audio_byte = model.text_to_speech(chunk)
7575

0 commit comments

Comments
 (0)