88import re
99import tempfile
1010import zipfile
11+ from functools import reduce
1112from typing import Dict
1213
1314import requests
3233from common .exception .app_exception import AppApiException
3334from common .field .common import UploadedImageField
3435from common .result import result
35- from common .utils .common import get_file_content
36+ from common .utils .common import get_file_content , generate_uuid
3637from common .utils .logger import maxkb_logger
3738from common .utils .rsa_util import rsa_long_decrypt , rsa_long_encrypt
3839from common .utils .tool_code import ToolExecutor
5152tool_executor = ToolExecutor ()
5253
5354
55+ def hand_node (node , update_tool_map ):
56+ if node .get ('type' ) == 'tool-lib-node' :
57+ tool_lib_id = (node .get ('properties' , {}).get ('node_data' , {}).get ('tool_lib_id' ) or '' )
58+ node .get ('properties' , {}).get ('node_data' , {})['tool_lib_id' ] = update_tool_map .get (tool_lib_id , tool_lib_id )
59+
60+ if node .get ('type' ) == 'tool-workflow-lib-node' :
61+ tool_lib_id = (node .get ('properties' , {}).get ('node_data' , {}).get ('tool_lib_id' ) or '' )
62+ node .get ('properties' , {}).get ('node_data' , {})['tool_lib_id' ] = update_tool_map .get (tool_lib_id , tool_lib_id )
63+
64+ if node .get ('type' ) == 'search-knowledge-node' :
65+ node .get ('properties' , {}).get ('node_data' , {})['knowledge_id_list' ] = []
66+ if node .get ('type' ) == 'ai-chat-node' :
67+ node_data = node .get ('properties' , {}).get ('node_data' , {})
68+ mcp_tool_ids = node_data .get ('mcp_tool_ids' ) or []
69+ node_data ['mcp_tool_ids' ] = [update_tool_map .get (tool_id ,
70+ tool_id ) for tool_id in mcp_tool_ids ]
71+ tool_ids = node_data .get ('tool_ids' ) or []
72+ node_data ['tool_ids' ] = [update_tool_map .get (tool_id ,
73+ tool_id ) for tool_id in tool_ids ]
74+ if node .get ('type' ) == 'mcp-node' :
75+ mcp_tool_id = (node .get ('properties' , {}).get ('node_data' , {}).get ('mcp_tool_id' ) or '' )
76+ node .get ('properties' , {}).get ('node_data' , {})['mcp_tool_id' ] = update_tool_map .get (mcp_tool_id ,
77+ mcp_tool_id )
78+
79+
5480class ToolInstance :
5581 def __init__ (self , tool : dict , version : str ):
5682 self .tool = tool
@@ -631,6 +657,30 @@ def one(self):
631657 'is_publish' : is_publish
632658 }
633659
660+ def get_child_tool_list (self , work_flow , response ):
661+ from application .flow .tools import get_tool_id_list
662+ tool_id_list = get_tool_id_list (work_flow )
663+ tool_id_list = [tool_id for tool_id in tool_id_list if
664+ len ([r for r in response if r .get ('id' ) == tool_id ]) == 0 ]
665+ tool_list = []
666+ if len (tool_id_list ) > 0 :
667+ tool_list = QuerySet (Tool ).filter (id__in = tool_id_list ).exclude (scope = ToolScope .SHARED )
668+ work_flow_tools = [tool for tool in tool_list if tool .tool_type == ToolType .WORKFLOW ]
669+ if len (work_flow_tools ) > 0 :
670+ work_flow_tool_dict = {tw .tool_id : tw for tw in
671+ QuerySet (ToolWorkflow ).filter (tool_id__in = [t .id for t in work_flow_tools ])}
672+ for tool in tool_list :
673+ if tool .tool_type == ToolType .WORKFLOW :
674+ response .append ({** ToolExportModelSerializer (tool ).data ,
675+ 'work_flow' : work_flow_tool_dict .get (tool .id ).work_flow })
676+ self .get_child_tool_list (work_flow_tool_dict .get (tool .id ).work_flow , response )
677+ else :
678+ response .append (ToolExportModelSerializer (tool ).data )
679+ else :
680+ for tool in tool_list :
681+ response .append (ToolExportModelSerializer (tool ).data )
682+ return response
683+
634684 def export (self ):
635685 try :
636686 self .is_valid ()
@@ -642,6 +692,11 @@ def export(self):
642692 skill_file = QuerySet (File ).filter (id = tool .code ).first ()
643693 if skill_file :
644694 tool_dict ['code' ] = base64 .b64encode (skill_file .get_bytes ()).decode ('utf-8' )
695+ if tool .tool_type == ToolType .WORKFLOW :
696+ workflow = QuerySet (ToolWorkflow ).filter (tool_id = tool .id ).first ()
697+ if workflow :
698+ tool_dict ['work_flow' ] = workflow .work_flow
699+ tool_dict ['tool_list' ] = self .get_child_tool_list (workflow .work_flow , [])
645700 mk_instance = ToolInstance (tool_dict , 'v2' )
646701 tool_pickle = pickle .dumps (mk_instance )
647702 response = HttpResponse (content_type = 'text/plain' , content = tool_pickle )
@@ -674,7 +729,84 @@ class Import(serializers.Serializer):
674729 workspace_id = serializers .CharField (required = True , label = _ ("workspace id" ))
675730 folder_id = serializers .CharField (required = False , allow_null = True , label = _ ("folder id" ))
676731
677- #
732+ @staticmethod
733+ def to_tool_workflow (work_flow , update_tool_map ):
734+ for node in work_flow .get ('nodes' , []):
735+ hand_node (node , update_tool_map )
736+ if node .get ('type' ) == 'loop_node' :
737+ for n in node .get ('properties' , {}).get ('node_data' , {}).get ('loop_body' , {}).get ('nodes' , []):
738+ hand_node (n , update_tool_map )
739+ return work_flow
740+
741+ @staticmethod
742+ def to_tool (tool , workspace_id , user_id ):
743+ return Tool (id = tool .get ('id' ),
744+ user_id = user_id ,
745+ name = tool .get ('name' ),
746+ code = tool .get ('code' ),
747+ template_id = tool .get ('template_id' ),
748+ input_field_list = tool .get ('input_field_list' ),
749+ init_field_list = tool .get ('init_field_list' ),
750+ is_active = False if len ((tool .get ('init_field_list' ) or [])) > 0 else tool .get ('is_active' ),
751+ tool_type = tool .get ('tool_type' , 'CUSTOM' ) or 'CUSTOM' ,
752+ scope = ToolScope .SHARED if workspace_id == 'None' else ToolScope .WORKSPACE ,
753+ folder_id = 'default' if workspace_id == 'None' else workspace_id ,
754+ workspace_id = workspace_id )
755+
756+ def import_workflow_tools (self , tool , workspace_id , user_id ):
757+ tool_list = tool .get ('tool_list' ) or []
758+ update_tool_map = {}
759+ if len (tool_list ) > 0 :
760+ tool_id_list = reduce (lambda x , y : [* x , * y ],
761+ [[tool .get ('id' ), generate_uuid ((tool .get ('id' ) + workspace_id or '' ))]
762+ for tool
763+ in
764+ tool_list ], [])
765+ # 存在的工具列表
766+ exits_tool_id_list = [str (tool .id ) for tool in
767+ QuerySet (Tool ).filter (id__in = tool_id_list , workspace_id = workspace_id )]
768+ # 需要更新的工具集合
769+ update_tool_map = {tool .get ('id' ): generate_uuid ((tool .get ('id' ) + workspace_id or '' )) for tool
770+ in
771+ tool_list if
772+ not exits_tool_id_list .__contains__ (
773+ tool .get ('id' ))}
774+
775+ tool_list = [{** tool , 'id' : update_tool_map .get (tool .get ('id' ))} for tool in tool_list if
776+ not exits_tool_id_list .__contains__ (
777+ tool .get ('id' )) and not exits_tool_id_list .__contains__ (
778+ generate_uuid ((tool .get ('id' ) + workspace_id or '' )))]
779+
780+ work_flow = self .to_tool_workflow (
781+ tool .get ('work_flow' ),
782+ update_tool_map ,
783+ )
784+ tool_model_list = [self .to_tool (tool , workspace_id , user_id ) for tool in tool_list ]
785+ workflow_tool_model_list = [{'tool_id' : t .get ('id' ), 'workflow' : self .to_tool_workflow (
786+ t .get ('work_flow' ),
787+ update_tool_map ,
788+ )} for t in tool_list if tool .get ('tool_type' ) == ToolType .WORKFLOW ]
789+ workflow_tool_model_list .append ({'tool_id' : tool .get ('id' ), 'workflow' : work_flow })
790+ existing_records = QuerySet (ToolWorkflow ).filter (
791+ tool_id__in = [wt .get ('tool_id' ) for wt in workflow_tool_model_list ],
792+ workspace_id = workspace_id )
793+ existing_map = {
794+ record .tool_id : record
795+ for record in existing_records
796+ }
797+ QuerySet (ToolWorkflow ).bulk_create (
798+ [ToolWorkflow (work_flow = wt .get ('workflow' ), workspace_id = workspace_id ,
799+ tool_id = wt .get ('tool_id' )) for wt in
800+ workflow_tool_model_list if wt .get ('tool_id' ) not in existing_map ])
801+
802+ if len (tool_model_list ) > 0 :
803+ QuerySet (Tool ).bulk_create (tool_model_list )
804+ UserResourcePermissionSerializer (data = {
805+ 'workspace_id' : self .data .get ('workspace_id' ),
806+ 'user_id' : self .data .get ('user_id' ),
807+ 'auth_target_type' : AuthTargetType .TOOL .value
808+ }).auth_resource_batch ([t .id for t in tool_model_list ])
809+
678810 @transaction .atomic
679811 def import_ (self , scope = ToolScope .WORKSPACE ):
680812 self .is_valid ()
@@ -718,7 +850,9 @@ def import_(self, scope=ToolScope.WORKSPACE):
718850 is_active = False
719851 )
720852 tool_model .save ()
721-
853+ if tool .get ('tool_type' ) == ToolType .WORKFLOW :
854+ tool ['id' ] = tool_id
855+ self .import_workflow_tools (tool , workspace_id = self .data .get ('workspace_id' ), user_id = user_id )
722856 # 自动授权给创建者
723857 UserResourcePermissionSerializer (data = {
724858 'workspace_id' : self .data .get ('workspace_id' ),
0 commit comments