diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py index 10bc850e8c6..f2ceb11559a 100644 --- a/apps/application/flow/tools.py +++ b/apps/application/flow/tools.py @@ -6,7 +6,7 @@ @date:2024/6/6 15:15 @desc: """ -from tools.models import ToolRecord, Tool +from tools.models import ToolRecord, Tool, ToolScope from maxkb.const import CONFIG from knowledge.models.knowledge_action import State from knowledge.models import File @@ -891,7 +891,8 @@ def save_workflow_mapping(workflow, source_type, source_id, other_resource_mappi {(str(item.target_type) + str(item.target_id)): item for item in resource_mapping_list}.values()) -def get_tool_id_list(workflow): +def get_tool_id_list(workflow, with_deep=False): + from tools.models import ToolWorkflow, ToolType _result = [] for node in workflow.get('nodes', []): if node.get('type') == 'tool-lib-node': @@ -921,4 +922,34 @@ def get_tool_id_list(workflow): 'node_data', {}).get('mcp_tool_id') if mcp_tool_id: _result.append(mcp_tool_id) + if with_deep: + workflow_list = QuerySet(Tool).filter(id__in=_result, tool_type=ToolType.WORKFLOW) + tool_work_flow_list = QuerySet(ToolWorkflow).filter(tool_id__in=[wl.id for wl in workflow_list]) + for tool_work_flow in tool_work_flow_list: + child_tool_id_list = get_child_tool_id_list(tool_work_flow.work_flow, []) + for c in child_tool_id_list: + _result.append(c) return _result + + +def get_child_tool_id_list(work_flow, response): + from tools.models import ToolWorkflow, ToolType + tool_id_list = get_tool_id_list(work_flow, False) + tool_id_list = [tool_id for tool_id in tool_id_list if + len([r for r in response if r == tool_id]) == 0] + tool_list = [] + if len(tool_id_list) > 0: + tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) + work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW] + if len(work_flow_tools) > 0: + + work_flow_tool_dict = {tw.tool_id: tw for tw in + QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])} + for tool in tool_list: + response.append(str(tool.id)) + if tool.tool_type == ToolType.WORKFLOW: + get_child_tool_id_list(work_flow_tool_dict.get(tool.id).work_flow, response) + else: + for tool in tool_list: + response.append(str(tool.id)) + return response diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index abbcfd2b685..8bef8cc617e 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -55,7 +55,7 @@ from system_manage.models.resource_mapping import ResourceMapping from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer -from tools.models import Tool, ToolScope, ToolType +from tools.models import Tool, ToolScope, ToolType, ToolWorkflow from tools.serializers.tool import ToolExportModelSerializer from trigger.models import TriggerTask, Trigger from users.models import User @@ -93,6 +93,10 @@ def hand_node(node, update_tool_map): mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '') node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id, mcp_tool_id) + if node.get('type') == 'tool-workflow-lib-node': + tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') + node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, + tool_lib_id) class MKInstance: @@ -628,6 +632,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True): if is_import_tool: if len(tool_model_list) > 0: QuerySet(Tool).bulk_create(tool_model_list) + QuerySet(ToolWorkflow).bulk_create( + [ToolWorkflow(workspace_id=workspace_id, + work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map), + tool_id=tool.get('id')) + for + tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW]) UserResourcePermissionSerializer(data={ 'workspace_id': self.data.get('workspace_id'), 'user_id': self.data.get('user_id'), @@ -670,6 +680,15 @@ def to_tool(tool, workspace_id, user_id): folder_id=workspace_id, workspace_id=workspace_id) + @staticmethod + def reset_workflow(work_flow, update_tool_map): + for node in work_flow.get('nodes', []): + hand_node(node, update_tool_map) + if node.get('type') == 'loop-node': + for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []): + hand_node(n, update_tool_map) + return work_flow + @staticmethod def to_application(application, workspace_id, user_id, update_tool_map, folder_id): work_flow = application.get('work_flow') @@ -844,13 +863,16 @@ def export(self, with_valid=True): application_id = self.data.get('application_id') application = QuerySet(Application).filter(id=application_id).first() from application.flow.tools import get_tool_id_list - tool_id_list = get_tool_id_list(application.work_flow) + tool_id_list = get_tool_id_list(application.work_flow, True) if len(tool_id_list) > 0: tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) else: tool_list = QuerySet(Tool).filter( id__in=application.tool_ids + application.mcp_tool_ids + application.skill_tool_ids ).exclude(scope=ToolScope.SHARED) + tw_dict = {tw.tool_id: tw + for tw in QuerySet(ToolWorkflow).filter( + tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])} # 如果是技能工具,则需要将code字段转换为文件内容的base64字符串 for tool in tool_list: if tool.tool_type == ToolType.SKILL: @@ -862,7 +884,7 @@ def export(self, with_valid=True): mk_instance = MKInstance(application_dict, [], 'v2', - [ToolExportModelSerializer(tool).data for tool in + [self.to_tool_dict(tool, tw_dict) for tool in tool_list]) application_pickle = pickle.dumps(mk_instance) response = HttpResponse(content_type='text/plain', content=application_pickle) @@ -871,6 +893,12 @@ def export(self, with_valid=True): except Exception as e: return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR) + @staticmethod + def to_tool_dict(tool, tool_workflow_dict): + if tool.tool_type == ToolType.WORKFLOW: + return {**ToolExportModelSerializer(tool).data, 'work_flow': tool_workflow_dict.get(tool.id).work_flow} + return ToolExportModelSerializer(tool).data + @staticmethod def reset_application_version(application_version, application): update_field_dict = { @@ -1329,7 +1357,7 @@ def batch_delete(self, instance: Dict, with_valid=True): from trigger.serializers.trigger import TriggerModelSerializer if with_valid: - BatchSerializer(data=instance).is_valid(model=Application,raise_exception=True) + BatchSerializer(data=instance).is_valid(model=Application, raise_exception=True) self.is_valid(raise_exception=True) id_list = instance.get("id_list") workspace_id = self.data.get('workspace_id') diff --git a/apps/knowledge/serializers/knowledge_workflow.py b/apps/knowledge/serializers/knowledge_workflow.py index c9ffc608a99..24cf9a2d1e2 100644 --- a/apps/knowledge/serializers/knowledge_workflow.py +++ b/apps/knowledge/serializers/knowledge_workflow.py @@ -38,7 +38,7 @@ from system_manage.models import AuthTargetType from system_manage.models.resource_mapping import ResourceType from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer -from tools.models import Tool, ToolScope +from tools.models import Tool, ToolScope, ToolType, ToolWorkflow from tools.serializers.tool import ToolExportModelSerializer from users.models import User @@ -64,6 +64,10 @@ def hand_node(node, update_tool_map): mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '') node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id, mcp_tool_id) + if node.get('type') == 'tool-workflow-lib-node': + tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '') + node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, + tool_lib_id) class KnowledgeWorkflowModelSerializer(serializers.ModelSerializer): @@ -356,6 +360,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True): if is_import_tool: if len(tool_model_list) > 0: QuerySet(Tool).bulk_create(tool_model_list) + QuerySet(ToolWorkflow).bulk_create( + [ToolWorkflow(workspace_id=workspace_id, + work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map), + tool_id=tool.get('id')) + for + tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW]) UserResourcePermissionSerializer(data={ 'workspace_id': self.data.get('workspace_id'), 'user_id': self.data.get('user_id'), @@ -374,6 +384,15 @@ def to_knowledge_workflow(knowledge_workflow, update_tool_map): hand_node(n, update_tool_map) return work_flow + @staticmethod + def reset_workflow(work_flow, update_tool_map): + for node in work_flow.get('nodes', []): + hand_node(node, update_tool_map) + if node.get('type') == 'loop-node': + for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []): + hand_node(n, update_tool_map) + return work_flow + @staticmethod def to_tool(tool, workspace_id, user_id): return Tool(id=tool.get('id'), @@ -402,17 +421,20 @@ def export(self, with_valid=True): knowledge_workflow = QuerySet(KnowledgeWorkflow).filter(knowledge_id=knowledge_id).first() knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first() from application.flow.tools import get_tool_id_list - tool_id_list = get_tool_id_list(knowledge_workflow.work_flow) + tool_id_list = get_tool_id_list(knowledge_workflow.work_flow, True) tool_list = [] if len(tool_id_list) > 0: tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED) + tw_dict = {tw.tool_id: tw + for tw in QuerySet(ToolWorkflow).filter( + tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])} knowledge_workflow_dict = KnowledgeWorkflowModelSerializer(knowledge_workflow).data kbwf_instance = KBWFInstance( knowledge_workflow_dict, [], 'v2', - [ToolExportModelSerializer(tool).data for tool in tool_list] + [self.to_tool_dict(tool, tw_dict) for tool in tool_list] ) knowledge_workflow_pickle = pickle.dumps(kbwf_instance) response = HttpResponse(content_type='text/plain', content=knowledge_workflow_pickle) @@ -421,6 +443,12 @@ def export(self, with_valid=True): except Exception as e: return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR) + @staticmethod + def to_tool_dict(tool, tool_workflow_dict): + if tool.tool_type == ToolType.WORKFLOW: + return {**ToolExportModelSerializer(tool).data, 'work_flow': tool_workflow_dict.get(tool.id).work_flow} + return ToolExportModelSerializer(tool).data + class Operate(serializers.Serializer): user_id = serializers.UUIDField(required=True, label=_('user id')) workspace_id = serializers.CharField(required=True, label=_('workspace id')) diff --git a/apps/tools/serializers/tool.py b/apps/tools/serializers/tool.py index 6e9fb64ad89..3cd4715d7c3 100644 --- a/apps/tools/serializers/tool.py +++ b/apps/tools/serializers/tool.py @@ -659,7 +659,7 @@ def one(self): def get_child_tool_list(self, work_flow, response): from application.flow.tools import get_tool_id_list - tool_id_list = get_tool_id_list(work_flow) + tool_id_list = get_tool_id_list(work_flow, False) tool_id_list = [tool_id for tool_id in tool_id_list if len([r for r in response if r.get('id') == tool_id]) == 0] tool_list = [] @@ -1307,6 +1307,7 @@ def process(): return to_stream_response_simple(process()) + class ToolBatchOperateSerializer(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id')) @@ -1365,8 +1366,6 @@ def batch_move(self, instance: Dict, with_valid=True): return True - - class ToolTreeSerializer(serializers.Serializer): class Query(serializers.Serializer): workspace_id = serializers.CharField(required=True, label=_('workspace id'))