Skip to content

Commit ec1ef59

Browse files
authored
fix: Workflow tool import and export (#4976)
1 parent 7e6e5de commit ec1ef59

File tree

7 files changed

+146
-219
lines changed

7 files changed

+146
-219
lines changed

apps/application/flow/tools.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,11 @@ def get_tool_id_list(workflow):
904904
'node_data', {}).get('loop_body', {}))
905905
for item in r:
906906
_result.append(item)
907+
elif node.get('type') == 'tool-workflow-lib-node':
908+
tool_id = node.get('properties', {}).get(
909+
'node_data', {}).get('tool_lib_id')
910+
if tool_id:
911+
_result.append(tool_id)
907912
elif node.get('type') == 'ai-chat-node':
908913
node_data = node.get('properties', {}).get('node_data', {})
909914
mcp_tool_ids = node_data.get('mcp_tool_ids') or []

apps/tools/serializers/tool.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import tempfile
1010
import zipfile
11+
from functools import reduce
1112
from typing import Dict
1213

1314
import requests
@@ -32,7 +33,7 @@
3233
from common.exception.app_exception import AppApiException
3334
from common.field.common import UploadedImageField
3435
from common.result import result
35-
from common.utils.common import get_file_content
36+
from common.utils.common import get_file_content, generate_uuid
3637
from common.utils.logger import maxkb_logger
3738
from common.utils.rsa_util import rsa_long_decrypt, rsa_long_encrypt
3839
from common.utils.tool_code import ToolExecutor
@@ -51,6 +52,31 @@
5152
tool_executor = ToolExecutor()
5253

5354

55+
def hand_node(node, update_tool_map):
56+
if node.get('type') == 'tool-lib-node':
57+
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
58+
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, tool_lib_id)
59+
60+
if node.get('type') == 'tool-workflow-lib-node':
61+
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
62+
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, tool_lib_id)
63+
64+
if node.get('type') == 'search-knowledge-node':
65+
node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = []
66+
if node.get('type') == 'ai-chat-node':
67+
node_data = node.get('properties', {}).get('node_data', {})
68+
mcp_tool_ids = node_data.get('mcp_tool_ids') or []
69+
node_data['mcp_tool_ids'] = [update_tool_map.get(tool_id,
70+
tool_id) for tool_id in mcp_tool_ids]
71+
tool_ids = node_data.get('tool_ids') or []
72+
node_data['tool_ids'] = [update_tool_map.get(tool_id,
73+
tool_id) for tool_id in tool_ids]
74+
if node.get('type') == 'mcp-node':
75+
mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
76+
node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
77+
mcp_tool_id)
78+
79+
5480
class ToolInstance:
5581
def __init__(self, tool: dict, version: str):
5682
self.tool = tool
@@ -631,6 +657,30 @@ def one(self):
631657
'is_publish': is_publish
632658
}
633659

660+
def get_child_tool_list(self, work_flow, response):
661+
from application.flow.tools import get_tool_id_list
662+
tool_id_list = get_tool_id_list(work_flow)
663+
tool_id_list = [tool_id for tool_id in tool_id_list if
664+
len([r for r in response if r.get('id') == tool_id]) == 0]
665+
tool_list = []
666+
if len(tool_id_list) > 0:
667+
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
668+
work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW]
669+
if len(work_flow_tools) > 0:
670+
work_flow_tool_dict = {tw.tool_id: tw for tw in
671+
QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])}
672+
for tool in tool_list:
673+
if tool.tool_type == ToolType.WORKFLOW:
674+
response.append({**ToolExportModelSerializer(tool).data,
675+
'work_flow': work_flow_tool_dict.get(tool.id).work_flow})
676+
self.get_child_tool_list(work_flow_tool_dict.get(tool.id).work_flow, response)
677+
else:
678+
response.append(ToolExportModelSerializer(tool).data)
679+
else:
680+
for tool in tool_list:
681+
response.append(ToolExportModelSerializer(tool).data)
682+
return response
683+
634684
def export(self):
635685
try:
636686
self.is_valid()
@@ -642,6 +692,11 @@ def export(self):
642692
skill_file = QuerySet(File).filter(id=tool.code).first()
643693
if skill_file:
644694
tool_dict['code'] = base64.b64encode(skill_file.get_bytes()).decode('utf-8')
695+
if tool.tool_type == ToolType.WORKFLOW:
696+
workflow = QuerySet(ToolWorkflow).filter(tool_id=tool.id).first()
697+
if workflow:
698+
tool_dict['work_flow'] = workflow.work_flow
699+
tool_dict['tool_list'] = self.get_child_tool_list(workflow.work_flow, [])
645700
mk_instance = ToolInstance(tool_dict, 'v2')
646701
tool_pickle = pickle.dumps(mk_instance)
647702
response = HttpResponse(content_type='text/plain', content=tool_pickle)
@@ -674,7 +729,84 @@ class Import(serializers.Serializer):
674729
workspace_id = serializers.CharField(required=True, label=_("workspace id"))
675730
folder_id = serializers.CharField(required=False, allow_null=True, label=_("folder id"))
676731

677-
#
732+
@staticmethod
733+
def to_tool_workflow(work_flow, update_tool_map):
734+
for node in work_flow.get('nodes', []):
735+
hand_node(node, update_tool_map)
736+
if node.get('type') == 'loop_node':
737+
for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []):
738+
hand_node(n, update_tool_map)
739+
return work_flow
740+
741+
@staticmethod
742+
def to_tool(tool, workspace_id, user_id):
743+
return Tool(id=tool.get('id'),
744+
user_id=user_id,
745+
name=tool.get('name'),
746+
code=tool.get('code'),
747+
template_id=tool.get('template_id'),
748+
input_field_list=tool.get('input_field_list'),
749+
init_field_list=tool.get('init_field_list'),
750+
is_active=False if len((tool.get('init_field_list') or [])) > 0 else tool.get('is_active'),
751+
tool_type=tool.get('tool_type', 'CUSTOM') or 'CUSTOM',
752+
scope=ToolScope.SHARED if workspace_id == 'None' else ToolScope.WORKSPACE,
753+
folder_id='default' if workspace_id == 'None' else workspace_id,
754+
workspace_id=workspace_id)
755+
756+
def import_workflow_tools(self, tool, workspace_id, user_id):
757+
tool_list = tool.get('tool_list') or []
758+
update_tool_map = {}
759+
if len(tool_list) > 0:
760+
tool_id_list = reduce(lambda x, y: [*x, *y],
761+
[[tool.get('id'), generate_uuid((tool.get('id') + workspace_id or ''))]
762+
for tool
763+
in
764+
tool_list], [])
765+
# 存在的工具列表
766+
exits_tool_id_list = [str(tool.id) for tool in
767+
QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)]
768+
# 需要更新的工具集合
769+
update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + workspace_id or '')) for tool
770+
in
771+
tool_list if
772+
not exits_tool_id_list.__contains__(
773+
tool.get('id'))}
774+
775+
tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if
776+
not exits_tool_id_list.__contains__(
777+
tool.get('id')) and not exits_tool_id_list.__contains__(
778+
generate_uuid((tool.get('id') + workspace_id or '')))]
779+
780+
work_flow = self.to_tool_workflow(
781+
tool.get('work_flow'),
782+
update_tool_map,
783+
)
784+
tool_model_list = [self.to_tool(tool, workspace_id, user_id) for tool in tool_list]
785+
workflow_tool_model_list = [{'tool_id': t.get('id'), 'workflow': self.to_tool_workflow(
786+
t.get('work_flow'),
787+
update_tool_map,
788+
)} for t in tool_list if tool.get('tool_type') == ToolType.WORKFLOW]
789+
workflow_tool_model_list.append({'tool_id': tool.get('id'), 'workflow': work_flow})
790+
existing_records = QuerySet(ToolWorkflow).filter(
791+
tool_id__in=[wt.get('tool_id') for wt in workflow_tool_model_list],
792+
workspace_id=workspace_id)
793+
existing_map = {
794+
record.tool_id: record
795+
for record in existing_records
796+
}
797+
QuerySet(ToolWorkflow).bulk_create(
798+
[ToolWorkflow(work_flow=wt.get('workflow'), workspace_id=workspace_id,
799+
tool_id=wt.get('tool_id')) for wt in
800+
workflow_tool_model_list if wt.get('tool_id') not in existing_map])
801+
802+
if len(tool_model_list) > 0:
803+
QuerySet(Tool).bulk_create(tool_model_list)
804+
UserResourcePermissionSerializer(data={
805+
'workspace_id': self.data.get('workspace_id'),
806+
'user_id': self.data.get('user_id'),
807+
'auth_target_type': AuthTargetType.TOOL.value
808+
}).auth_resource_batch([t.id for t in tool_model_list])
809+
678810
@transaction.atomic
679811
def import_(self, scope=ToolScope.WORKSPACE):
680812
self.is_valid()
@@ -718,7 +850,9 @@ def import_(self, scope=ToolScope.WORKSPACE):
718850
is_active=False
719851
)
720852
tool_model.save()
721-
853+
if tool.get('tool_type') == ToolType.WORKFLOW:
854+
tool['id'] = tool_id
855+
self.import_workflow_tools(tool, workspace_id=self.data.get('workspace_id'), user_id=user_id)
722856
# 自动授权给创建者
723857
UserResourcePermissionSerializer(data={
724858
'workspace_id': self.data.get('workspace_id'),

apps/tools/serializers/tool_workflow.py

Lines changed: 0 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -90,129 +90,6 @@ def get_tool_list(self):
9090

9191

9292
class ToolWorkflowSerializer(serializers.Serializer):
93-
class Import(serializers.Serializer):
94-
user_id = serializers.UUIDField(required=True, label=_('user id'))
95-
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
96-
knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id'))
97-
98-
@transaction.atomic
99-
def import_(self, instance: dict, is_import_tool, with_valid=True):
100-
if with_valid:
101-
self.is_valid()
102-
ToolWorkflowSerializer(data=instance).is_valid(raise_exception=True)
103-
user_id = self.data.get('user_id')
104-
workspace_id = self.data.get('workspace_id')
105-
tool_id = self.data.get('tool_id')
106-
tool_instance_bytes = instance.get('file').read()
107-
try:
108-
tool_instance = restricted_loads(tool_instance_bytes)
109-
except Exception as e:
110-
raise AppApiException(1001, _("Unsupported file format"))
111-
tool_workflow = tool_instance.work_flow
112-
tool_list = tool_instance.get_tool_list()
113-
update_tool_map = {}
114-
if len(tool_list) > 0:
115-
tool_id_list = reduce(lambda x, y: [*x, *y],
116-
[[tool.get('id'), generate_uuid((tool.get('id') + workspace_id or ''))]
117-
for tool
118-
in
119-
tool_list], [])
120-
# 存在的工具列表
121-
exits_tool_id_list = [str(tool.id) for tool in
122-
QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)]
123-
# 需要更新的工具集合
124-
update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + workspace_id or '')) for tool
125-
in
126-
tool_list if
127-
not exits_tool_id_list.__contains__(
128-
tool.get('id'))}
129-
130-
tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if
131-
not exits_tool_id_list.__contains__(
132-
tool.get('id')) and not exits_tool_id_list.__contains__(
133-
generate_uuid((tool.get('id') + workspace_id or '')))]
134-
135-
work_flow = self.to_tool_workflow(
136-
tool_workflow,
137-
update_tool_map,
138-
)
139-
tool_model_list = [self.to_tool(tool, workspace_id, user_id) for tool in tool_list]
140-
QuerySet(ToolWorkflow).filter(workspace_id=workspace_id, tool_id=tool_id).update_or_create(
141-
tool_id=tool_id,
142-
workspace_id=workspace_id,
143-
defaults={'work_flow': work_flow}
144-
)
145-
146-
if is_import_tool:
147-
if len(tool_model_list) > 0:
148-
QuerySet(Tool).bulk_create(tool_model_list)
149-
UserResourcePermissionSerializer(data={
150-
'workspace_id': self.data.get('workspace_id'),
151-
'user_id': self.data.get('user_id'),
152-
'auth_target_type': AuthTargetType.TOOL.value
153-
}).auth_resource_batch([t.id for t in tool_model_list])
154-
155-
@staticmethod
156-
def to_tool_workflow(knowledge_workflow, update_tool_map):
157-
work_flow = knowledge_workflow.get("work_flow")
158-
for node in work_flow.get('nodes', []):
159-
hand_node(node, update_tool_map)
160-
if node.get('type') == 'loop_node':
161-
for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []):
162-
hand_node(n, update_tool_map)
163-
return work_flow
164-
165-
@staticmethod
166-
def to_tool(tool, workspace_id, user_id):
167-
return Tool(id=tool.get('id'),
168-
user_id=user_id,
169-
name=tool.get('name'),
170-
code=tool.get('code'),
171-
template_id=tool.get('template_id'),
172-
input_field_list=tool.get('input_field_list'),
173-
init_field_list=tool.get('init_field_list'),
174-
is_active=False if len((tool.get('init_field_list') or [])) > 0 else tool.get('is_active'),
175-
tool_type=tool.get('tool_type', 'CUSTOM') or 'CUSTOM',
176-
scope=ToolScope.SHARED if workspace_id == 'None' else ToolScope.WORKSPACE,
177-
folder_id='default' if workspace_id == 'None' else workspace_id,
178-
workspace_id=workspace_id)
179-
180-
class Export(serializers.Serializer):
181-
user_id = serializers.UUIDField(required=True, label=_('user id'))
182-
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
183-
tool_id = serializers.UUIDField(required=True, label=_('knowledge id'))
184-
185-
def export(self, with_valid=True):
186-
try:
187-
if with_valid:
188-
self.is_valid()
189-
tool_id = self.data.get('tool_id')
190-
tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first()
191-
tool = QuerySet(Tool).filter(id=tool_id).first()
192-
from application.flow.tools import get_tool_id_list
193-
tool_id_list = get_tool_id_list(tool_workflow.work_flow)
194-
tool_list = []
195-
if len(tool_id_list) > 0:
196-
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
197-
tool_workflow_dict = {'id': tool.id,
198-
'work_flow': tool_workflow.work_flow,
199-
'workspace_id': tool.workspace_id,
200-
'name': tool.name,
201-
'desc': tool.desc,
202-
'tool_type': tool.tool_type}
203-
204-
tool_workflow_instance = ToolWorkflowInstance(
205-
tool_workflow_dict,
206-
'v2',
207-
[ToolExportModelSerializer(tool).data for tool in tool_list]
208-
)
209-
tool_workflow_pickle = pickle.dumps(tool_workflow_instance)
210-
response = HttpResponse(content_type='text/plain', content=tool_workflow_pickle)
211-
response['Content-Disposition'] = f'attachment; filename="{tool.name}.tool"'
212-
return response
213-
except Exception as e:
214-
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)
215-
21693
class Operate(serializers.Serializer):
21794
user_id = serializers.UUIDField(required=True, label=_('user id'))
21895
workspace_id = serializers.CharField(required=True, label=_('workspace id'))

apps/tools/urls.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
path('workspace/<str:workspace_id>/tool/<str:tool_id>/publish', views.ToolWorkflowView.Publish.as_view()),
2323
path('workspace/<str:workspace_id>/tool/<str:tool_id>/debug', views.ToolWorkflowDebugView.as_view()),
2424
path('workspace/<str:workspace_id>/tool/<str:tool_id>/workflow', views.ToolWorkflowView.Operate.as_view()),
25-
path('workspace/<str:workspace_id>/tool/<str:tool_id>/workflow/export', views.ToolWorkflowView.Export.as_view()),
2625
path('workspace/<str:workspace_id>/tool/<str:tool_id>/edit_icon', views.ToolView.EditIcon.as_view()),
2726
path('workspace/<str:workspace_id>/tool/<str:tool_id>/export', views.ToolView.Export.as_view()),
2827
path('workspace/<str:workspace_id>/tool/<str:tool_id>/add_internal_tool', views.ToolView.AddInternalTool.as_view()),

0 commit comments

Comments
 (0)