diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 77117769769..10bc850e8c6 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -904,6 +904,11 @@ def get_tool_id_list(workflow): 'node_data', {}).get('loop_body', {})) for item in r: _result.append(item) + elif node.get('type') == 'tool-workflow-lib-node': + tool_id = node.get('properties', {}).get( + 'node_data', {}).get('tool_lib_id') + if tool_id: + _result.append(tool_id) elif node.get('type') == 'ai-chat-node': node_data = node.get('properties', {}).get('node_data', {}) mcp_tool_ids = node_data.get('mcp_tool_ids') or [] diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 5d071fb373d..6e9fb64ad89 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -8,6 +8,7 @@ import re import tempfile import zipfile +from functools import reduce from typing import Dict import requests @@ -32,7 +33,7 @@ from common.exception.app_exception import AppApiException from common.field.common import UploadedImageField from common.result import result -from common.utils.common import get_file_content +from common.utils.common import get_file_content, generate_uuid from common.utils.logger import maxkb_logger from common.utils.rsa_util import rsa_long_decrypt, rsa_long_encrypt from common.utils.tool_code import ToolExecutor @@ -51,6 +52,31 @@ tool_executor = ToolExecutor() +def hand_node(node, update_tool_map): + if node.get('type') == 'tool-lib-node': + tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') + node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, tool_lib_id) + + if node.get('type') == 'tool-workflow-lib-node': + tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') + node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, tool_lib_id) + + if 
node.get('type') == 'search-knowledge-node': + node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = [] + if node.get('type') == 'ai-chat-node': + node_data = node.get('properties', {}).get('node_data', {}) + mcp_tool_ids = node_data.get('mcp_tool_ids') or [] + node_data['mcp_tool_ids'] = [update_tool_map.get(tool_id, + tool_id) for tool_id in mcp_tool_ids] + tool_ids = node_data.get('tool_ids') or [] + node_data['tool_ids'] = [update_tool_map.get(tool_id, + tool_id) for tool_id in tool_ids] + if node.get('type') == 'mcp-node': + mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '') + node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id, + mcp_tool_id) + + class ToolInstance: def __init__(self, tool: dict, version: str): self.tool = tool @@ -631,6 +657,30 @@ def one(self): 'is_publish': is_publish } + def get_child_tool_list(self, work_flow, response): + from application.flow.tools import get_tool_id_list + tool_id_list = get_tool_id_list(work_flow) + tool_id_list = [tool_id for tool_id in tool_id_list if + len([r for r in response if r.get('id') == tool_id]) == 0] + tool_list = [] + if len(tool_id_list) > 0: + tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) + work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW] + if len(work_flow_tools) > 0: + work_flow_tool_dict = {tw.tool_id: tw for tw in + QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])} + for tool in tool_list: + if tool.tool_type == ToolType.WORKFLOW: + response.append({**ToolExportModelSerializer(tool).data, + 'work_flow': work_flow_tool_dict.get(tool.id).work_flow}) + self.get_child_tool_list(work_flow_tool_dict.get(tool.id).work_flow, response) + else: + response.append(ToolExportModelSerializer(tool).data) + else: + for tool in tool_list: + response.append(ToolExportModelSerializer(tool).data) + return response + 
def export(self): try: self.is_valid() @@ -642,6 +692,11 @@ def export(self): skill_file = QuerySet(File).filter(id=tool.code).first() if skill_file: tool_dict['code'] = base64.b64encode(skill_file.get_bytes()).decode('utf-8') + if tool.tool_type == ToolType.WORKFLOW: + workflow = QuerySet(ToolWorkflow).filter(tool_id=tool.id).first() + if workflow: + tool_dict['work_flow'] = workflow.work_flow + tool_dict['tool_list'] = self.get_child_tool_list(workflow.work_flow, []) mk_instance = ToolInstance(tool_dict, 'v2') tool_pickle = pickle.dumps(mk_instance) response = HttpResponse(content_type='text/plain', content=tool_pickle) @@ -674,7 +729,84 @@ class Import(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_("workspace id")) folder_id = serializers.CharField(required=False, allow_null=True, label=_("folder id")) - # + @staticmethod + def to_tool_workflow(work_flow, update_tool_map): + for node in work_flow.get('nodes', []): + hand_node(node, update_tool_map) + if node.get('type') == 'loop_node': + for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []): + hand_node(n, update_tool_map) + return work_flow + + @staticmethod + def to_tool(tool, workspace_id, user_id): + return Tool(id=tool.get('id'), + user_id=user_id, + name=tool.get('name'), + code=tool.get('code'), + template_id=tool.get('template_id'), + input_field_list=tool.get('input_field_list'), + init_field_list=tool.get('init_field_list'), + is_active=False if len((tool.get('init_field_list') or [])) > 0 else tool.get('is_active'), + tool_type=tool.get('tool_type', 'CUSTOM') or 'CUSTOM', + scope=ToolScope.SHARED if workspace_id == 'None' else ToolScope.WORKSPACE, + folder_id='default' if workspace_id == 'None' else workspace_id, + workspace_id=workspace_id) + + def import_workflow_tools(self, tool, workspace_id, user_id): + tool_list = tool.get('tool_list') or [] + update_tool_map = {} + if len(tool_list) > 0: + tool_id_list = 
reduce(lambda x, y: [*x, *y], + [[tool.get('id'), generate_uuid((tool.get('id') + (workspace_id or '')))] + for tool + in + tool_list], []) + # 存在的工具列表 + exits_tool_id_list = [str(tool.id) for tool in + QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)] + # 需要更新的工具集合 + update_tool_map = {tool.get('id'): generate_uuid((tool.get('id') + (workspace_id or ''))) for tool + in + tool_list if + not exits_tool_id_list.__contains__( + tool.get('id'))} + + tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if + not exits_tool_id_list.__contains__( + tool.get('id')) and not exits_tool_id_list.__contains__( + generate_uuid((tool.get('id') + (workspace_id or ''))))] + + work_flow = self.to_tool_workflow( + tool.get('work_flow'), + update_tool_map, + ) + tool_model_list = [self.to_tool(tool, workspace_id, user_id) for tool in tool_list] + workflow_tool_model_list = [{'tool_id': t.get('id'), 'workflow': self.to_tool_workflow( + t.get('work_flow'), + update_tool_map, + )} for t in tool_list if t.get('tool_type') == ToolType.WORKFLOW] + workflow_tool_model_list.append({'tool_id': tool.get('id'), 'workflow': work_flow}) + existing_records = QuerySet(ToolWorkflow).filter( + tool_id__in=[wt.get('tool_id') for wt in workflow_tool_model_list], + workspace_id=workspace_id) + existing_map = { + record.tool_id: record + for record in existing_records + } + QuerySet(ToolWorkflow).bulk_create( + [ToolWorkflow(work_flow=wt.get('workflow'), workspace_id=workspace_id, + tool_id=wt.get('tool_id')) for wt in + workflow_tool_model_list if wt.get('tool_id') not in existing_map]) + + if len(tool_model_list) > 0: + QuerySet(Tool).bulk_create(tool_model_list) + UserResourcePermissionSerializer(data={ + 'workspace_id': self.data.get('workspace_id'), + 'user_id': self.data.get('user_id'), + 'auth_target_type': AuthTargetType.TOOL.value + }).auth_resource_batch([t.id for t in tool_model_list]) + @transaction.atomic def import_(self,
scope=ToolScope.WORKSPACE): self.is_valid() @@ -718,7 +850,9 @@ def import_(self, scope=ToolScope.WORKSPACE): is_active=False ) tool_model.save() - + if tool.get('tool_type') == ToolType.WORKFLOW: + tool['id'] = tool_id + self.import_workflow_tools(tool, workspace_id=self.data.get('workspace_id'), user_id=user_id) # 自动授权给创建者 UserResourcePermissionSerializer(data={ 'workspace_id': self.data.get('workspace_id'), diff --git a/apps/tools/serializers/tool_workflow.py b/apps/tools/serializers/tool_workflow.py index cffe8c2a6db..a80cdee4ff1 100644 --- a/apps/tools/serializers/tool_workflow.py +++ b/apps/tools/serializers/tool_workflow.py @@ -90,129 +90,6 @@ def get_tool_list(self): class ToolWorkflowSerializer(serializers.Serializer): - class Import(serializers.Serializer): - user_id = serializers.UUIDField(required=True, label=_('user id')) - workspace_id = serializers.CharField(required=False, label=_('workspace id')) - knowledge_id = serializers.UUIDField(required=True, label=_('knowledge id')) - - @transaction.atomic - def import_(self, instance: dict, is_import_tool, with_valid=True): - if with_valid: - self.is_valid() - ToolWorkflowSerializer(data=instance).is_valid(raise_exception=True) - user_id = self.data.get('user_id') - workspace_id = self.data.get('workspace_id') - tool_id = self.data.get('tool_id') - tool_instance_bytes = instance.get('file').read() - try: - tool_instance = restricted_loads(tool_instance_bytes) - except Exception as e: - raise AppApiException(1001, _("Unsupported file format")) - tool_workflow = tool_instance.work_flow - tool_list = tool_instance.get_tool_list() - update_tool_map = {} - if len(tool_list) > 0: - tool_id_list = reduce(lambda x, y: [*x, *y], - [[tool.get('id'), generate_uuid((tool.get('id') + workspace_id or ''))] - for tool - in - tool_list], []) - # 存在的工具列表 - exits_tool_id_list = [str(tool.id) for tool in - QuerySet(Tool).filter(id__in=tool_id_list, workspace_id=workspace_id)] - # 需要更新的工具集合 - update_tool_map = 
{tool.get('id'): generate_uuid((tool.get('id') + workspace_id or '')) for tool - in - tool_list if - not exits_tool_id_list.__contains__( - tool.get('id'))} - - tool_list = [{**tool, 'id': update_tool_map.get(tool.get('id'))} for tool in tool_list if - not exits_tool_id_list.__contains__( - tool.get('id')) and not exits_tool_id_list.__contains__( - generate_uuid((tool.get('id') + workspace_id or '')))] - - work_flow = self.to_tool_workflow( - tool_workflow, - update_tool_map, - ) - tool_model_list = [self.to_tool(tool, workspace_id, user_id) for tool in tool_list] - QuerySet(ToolWorkflow).filter(workspace_id=workspace_id, tool_id=tool_id).update_or_create( - tool_id=tool_id, - workspace_id=workspace_id, - defaults={'work_flow': work_flow} - ) - - if is_import_tool: - if len(tool_model_list) > 0: - QuerySet(Tool).bulk_create(tool_model_list) - UserResourcePermissionSerializer(data={ - 'workspace_id': self.data.get('workspace_id'), - 'user_id': self.data.get('user_id'), - 'auth_target_type': AuthTargetType.TOOL.value - }).auth_resource_batch([t.id for t in tool_model_list]) - - @staticmethod - def to_tool_workflow(knowledge_workflow, update_tool_map): - work_flow = knowledge_workflow.get("work_flow") - for node in work_flow.get('nodes', []): - hand_node(node, update_tool_map) - if node.get('type') == 'loop_node': - for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []): - hand_node(n, update_tool_map) - return work_flow - - @staticmethod - def to_tool(tool, workspace_id, user_id): - return Tool(id=tool.get('id'), - user_id=user_id, - name=tool.get('name'), - code=tool.get('code'), - template_id=tool.get('template_id'), - input_field_list=tool.get('input_field_list'), - init_field_list=tool.get('init_field_list'), - is_active=False if len((tool.get('init_field_list') or [])) > 0 else tool.get('is_active'), - tool_type=tool.get('tool_type', 'CUSTOM') or 'CUSTOM', - scope=ToolScope.SHARED if workspace_id == 'None' else 
ToolScope.WORKSPACE, - folder_id='default' if workspace_id == 'None' else workspace_id, - workspace_id=workspace_id) - - class Export(serializers.Serializer): - user_id = serializers.UUIDField(required=True, label=_('user id')) - workspace_id = serializers.CharField(required=False, label=_('workspace id')) - tool_id = serializers.UUIDField(required=True, label=_('knowledge id')) - - def export(self, with_valid=True): - try: - if with_valid: - self.is_valid() - tool_id = self.data.get('tool_id') - tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first() - tool = QuerySet(Tool).filter(id=tool_id).first() - from application.flow.tools import get_tool_id_list - tool_id_list = get_tool_id_list(tool_workflow.work_flow) - tool_list = [] - if len(tool_id_list) > 0: - tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) - tool_workflow_dict = {'id': tool.id, - 'work_flow': tool_workflow.work_flow, - 'workspace_id': tool.workspace_id, - 'name': tool.name, - 'desc': tool.desc, - 'tool_type': tool.tool_type} - - tool_workflow_instance = ToolWorkflowInstance( - tool_workflow_dict, - 'v2', - [ToolExportModelSerializer(tool).data for tool in tool_list] - ) - tool_workflow_pickle = pickle.dumps(tool_workflow_instance) - response = HttpResponse(content_type='text/plain', content=tool_workflow_pickle) - response['Content-Disposition'] = f'attachment; filename="{tool.name}.tool"' - return response - except Exception as e: - return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR) - class Operate(serializers.Serializer): user_id = serializers.UUIDField(required=True, label=_('user id')) workspace_id = serializers.CharField(required=True, label=_('workspace id')) diff --git a/apps/tools/urls.py b/apps/tools/urls.py index 075ed3e7bac..c199707dc15 100644 --- a/apps/tools/urls.py +++ b/apps/tools/urls.py @@ -22,7 +22,6 @@ path('workspace//tool//publish', views.ToolWorkflowView.Publish.as_view()), 
path('workspace//tool//debug', views.ToolWorkflowDebugView.as_view()), path('workspace//tool//workflow', views.ToolWorkflowView.Operate.as_view()), - path('workspace//tool//workflow/export', views.ToolWorkflowView.Export.as_view()), path('workspace//tool//edit_icon', views.ToolView.EditIcon.as_view()), path('workspace//tool//export', views.ToolView.Export.as_view()), path('workspace//tool//add_internal_tool', views.ToolView.AddInternalTool.as_view()), diff --git a/apps/tools/views/tool_workflow.py b/apps/tools/views/tool_workflow.py index 3d2a08b8604..3f009489886 100644 --- a/apps/tools/views/tool_workflow.py +++ b/apps/tools/views/tool_workflow.py @@ -47,73 +47,6 @@ def put(self, request: Request, workspace_id: str, tool_id: str): data={'tool_id': tool_id, 'user_id': request.user.id, 'workspace_id': workspace_id, }).publish()) - class Export(APIView): - authentication_classes = [TokenAuth] - - @extend_schema( - methods=['GET'], - description=_('Export tool workflow'), - summary=_('Export tool workflow'), - operation_id=_('Export tool workflow'), # type: ignore - parameters=ToolWorkflowExportApi.get_parameters(), - request=None, - responses=ToolWorkflowExportApi.get_response(), - tags=[_('Tool')] # type: ignore - ) - @has_permissions( - PermissionConstants.TOOL_EXPORT.get_workspace_tool_permission(), - PermissionConstants.TOOL_EXPORT.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - ViewPermission( - [RoleConstants.USER.get_workspace_role()], - [PermissionConstants.KNOWLEDGE.get_workspace_tool_permission()], - CompareConstants.AND - ) - ) - @log(menu='Tool', operate="Export tool workflow", - get_operation_object=lambda r, k: get_tool_operation_object(k.get('tool_id')), - ) - def get(self, request: Request, workspace_id: str, tool_id: str): - return ToolWorkflowSerializer.Export( - data={'tool_id': tool_id, 'user_id': request.user.id, 'workspace_id': workspace_id} - ).export() - - class Import(APIView): - 
authentication_classes = [TokenAuth] - - @extend_schema( - methods=['POST'], - description=_('Import tool workflow'), - summary=_('Import tool workflow'), - operation_id=_('Import tool workflow'), # type: ignore - parameters=ToolWorkflowImportApi.get_parameters(), - request=ToolWorkflowImportApi.get_request(), - responses=ToolWorkflowImportApi.get_response(), - tags=[_('Tool')] # type: ignore - ) - @has_permissions( - PermissionConstants.TOOL_EXPORT.get_workspace_tool_permission(), - PermissionConstants.TOOL_EXPORT.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), - ViewPermission( - [RoleConstants.USER.get_workspace_role()], - [PermissionConstants.KNOWLEDGE.get_workspace_tool_permission()], - CompareConstants.AND - ) - ) - @log(menu='Tool', operate="Import tool workflow", - get_operation_object=lambda r, k: get_tool_operation_object(k.get('tool')), - ) - def post(self, request: Request, workspace_id: str, tool_id: str): - is_import_tool = get_is_permissions(request, workspace_id=workspace_id)( - PermissionConstants.TOOL_IMPORT.get_workspace_permission(), - PermissionConstants.TOOL_IMPORT.get_workspace_permission_workspace_manage_role(), - RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.USER.get_workspace_role() - ) - return result.success(ToolWorkflowSerializer.Import(data={ - 'tool_id': tool_id, 'user_id': request.user.id, 'workspace_id': workspace_id - }).import_({'file': request.FILES.get('file')}, is_import_tool)) - class Operate(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/tool/tool.ts b/ui/src/api/tool/tool.ts index de1d38d774d..f5b2ced6b42 100644 --- a/ui/src/api/tool/tool.ts +++ b/ui/src/api/tool/tool.ts @@ -229,21 +229,7 @@ const exportKnowledgeWorkflow = ( loading, ) } -/** - * 导出知识库工作流 - * @param knowledge_id - * @param knowledge_name - * @param loading - * @returns - */ -const exportToolWorkflow = (tool_id: string, tool_name: string, loading?: Ref) => { 
- return exportFile( - tool_name + '.tool', - `${prefix.value}/${tool_id}/workflow/export`, - undefined, - loading, - ) -} + /** * 导入工具工作流 */ @@ -300,17 +286,11 @@ const debugToolWorkflow: (tool_id: string, data: any) => Promise = (tool_id return postStream(`${p}${prefix.value}/${tool_id}/debug`, data) } -const generateCode: (data:any) => Promise> = ( - data: any, -) => { +const generateCode: (data: any) => Promise> = (data: any) => { const p = (window.MaxKB?.prefix ? window.MaxKB?.prefix : '/admin') + '/api' - return postStream( - `${p}${prefix.value}/generate_code`, - data, - ) + return postStream(`${p}${prefix.value}/generate_code`, data) } - export default { getToolList, getAllToolList, @@ -336,7 +316,6 @@ export default { listToolWorkflowVersion, updateToolWorkflowVersion, publish, - exportToolWorkflow, debugToolWorkflow, generateCode, } diff --git a/ui/src/views/tool-workflow/index.vue b/ui/src/views/tool-workflow/index.vue index 34e69573c7e..2f66fee3f61 100644 --- a/ui/src/views/tool-workflow/index.vue +++ b/ui/src/views/tool-workflow/index.vue @@ -413,7 +413,7 @@ const importKnowledgeWorkflow = (file: any) => { function exportToolWorkflow(name: string, id: string) { loadSharedApi({ type: 'tool', isShared: isShared.value, systemType: apiType.value }) - .exportToolWorkflow(id, name, loading) + .exportTool(id, name, loading) .catch((error: any) => { if (error.response.status !== 403) { error.response.data.text().then((res: string) => {