Skip to content

Commit 988a610

Browse files
authored
fix: [Application, Knowledge Base] Workflow Export Now Supports Exporting Tool Workflows (#4978)
1 parent 7e2f6fa commit 988a610

File tree

4 files changed

+98
-12
lines changed

4 files changed

+98
-12
lines changed

apps/application/flow/tools.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
@date:2024/6/6 15:15
77
@desc:
88
"""
9-
from tools.models import ToolRecord, Tool
9+
from tools.models import ToolRecord, Tool, ToolScope
1010
from maxkb.const import CONFIG
1111
from knowledge.models.knowledge_action import State
1212
from knowledge.models import File
@@ -891,7 +891,8 @@ def save_workflow_mapping(workflow, source_type, source_id, other_resource_mappi
891891
{(str(item.target_type) + str(item.target_id)): item for item in resource_mapping_list}.values())
892892

893893

894-
def get_tool_id_list(workflow):
894+
def get_tool_id_list(workflow, with_deep=False):
895+
from tools.models import ToolWorkflow, ToolType
895896
_result = []
896897
for node in workflow.get('nodes', []):
897898
if node.get('type') == 'tool-lib-node':
@@ -921,4 +922,34 @@ def get_tool_id_list(workflow):
921922
'node_data', {}).get('mcp_tool_id')
922923
if mcp_tool_id:
923924
_result.append(mcp_tool_id)
925+
if with_deep:
926+
workflow_list = QuerySet(Tool).filter(id__in=_result, tool_type=ToolType.WORKFLOW)
927+
tool_work_flow_list = QuerySet(ToolWorkflow).filter(tool_id__in=[wl.id for wl in workflow_list])
928+
for tool_work_flow in tool_work_flow_list:
929+
child_tool_id_list = get_child_tool_id_list(tool_work_flow.work_flow, [])
930+
for c in child_tool_id_list:
931+
_result.append(c)
924932
return _result
933+
934+
935+
def get_child_tool_id_list(work_flow, response):
936+
from tools.models import ToolWorkflow, ToolType
937+
tool_id_list = get_tool_id_list(work_flow, False)
938+
tool_id_list = [tool_id for tool_id in tool_id_list if
939+
len([r for r in response if r == tool_id]) == 0]
940+
tool_list = []
941+
if len(tool_id_list) > 0:
942+
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
943+
work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW]
944+
if len(work_flow_tools) > 0:
945+
946+
work_flow_tool_dict = {tw.tool_id: tw for tw in
947+
QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])}
948+
for tool in tool_list:
949+
response.append(str(tool.id))
950+
if tool.tool_type == ToolType.WORKFLOW:
951+
get_child_tool_id_list(work_flow_tool_dict.get(tool.id).work_flow, response)
952+
else:
953+
for tool in tool_list:
954+
response.append(str(tool.id))
955+
return response

apps/application/serializers/application.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from system_manage.models.resource_mapping import ResourceMapping
5656
from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer
5757
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
58-
from tools.models import Tool, ToolScope, ToolType
58+
from tools.models import Tool, ToolScope, ToolType, ToolWorkflow
5959
from tools.serializers.tool import ToolExportModelSerializer
6060
from trigger.models import TriggerTask, Trigger
6161
from users.models import User
@@ -93,6 +93,10 @@ def hand_node(node, update_tool_map):
9393
mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
9494
node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
9595
mcp_tool_id)
96+
if node.get('type') == 'tool-workflow-lib-node':
97+
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
98+
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
99+
tool_lib_id)
96100

97101

98102
class MKInstance:
@@ -628,6 +632,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True):
628632
if is_import_tool:
629633
if len(tool_model_list) > 0:
630634
QuerySet(Tool).bulk_create(tool_model_list)
635+
QuerySet(ToolWorkflow).bulk_create(
636+
[ToolWorkflow(workspace_id=workspace_id,
637+
work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map),
638+
tool_id=tool.get('id'))
639+
for
640+
tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW])
631641
UserResourcePermissionSerializer(data={
632642
'workspace_id': self.data.get('workspace_id'),
633643
'user_id': self.data.get('user_id'),
@@ -670,6 +680,15 @@ def to_tool(tool, workspace_id, user_id):
670680
folder_id=workspace_id,
671681
workspace_id=workspace_id)
672682

683+
@staticmethod
684+
def reset_workflow(work_flow, update_tool_map):
685+
for node in work_flow.get('nodes', []):
686+
hand_node(node, update_tool_map)
687+
if node.get('type') == 'loop-node':
688+
for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []):
689+
hand_node(n, update_tool_map)
690+
return work_flow
691+
673692
@staticmethod
674693
def to_application(application, workspace_id, user_id, update_tool_map, folder_id):
675694
work_flow = application.get('work_flow')
@@ -844,13 +863,16 @@ def export(self, with_valid=True):
844863
application_id = self.data.get('application_id')
845864
application = QuerySet(Application).filter(id=application_id).first()
846865
from application.flow.tools import get_tool_id_list
847-
tool_id_list = get_tool_id_list(application.work_flow)
866+
tool_id_list = get_tool_id_list(application.work_flow, True)
848867
if len(tool_id_list) > 0:
849868
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
850869
else:
851870
tool_list = QuerySet(Tool).filter(
852871
id__in=application.tool_ids + application.mcp_tool_ids + application.skill_tool_ids
853872
).exclude(scope=ToolScope.SHARED)
873+
tw_dict = {tw.tool_id: tw
874+
for tw in QuerySet(ToolWorkflow).filter(
875+
tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])}
854876
# 如果是技能工具,则需要将code字段转换为文件内容的base64字符串
855877
for tool in tool_list:
856878
if tool.tool_type == ToolType.SKILL:
@@ -862,7 +884,7 @@ def export(self, with_valid=True):
862884
mk_instance = MKInstance(application_dict,
863885
[],
864886
'v2',
865-
[ToolExportModelSerializer(tool).data for tool in
887+
[self.to_tool_dict(tool, tw_dict) for tool in
866888
tool_list])
867889
application_pickle = pickle.dumps(mk_instance)
868890
response = HttpResponse(content_type='text/plain', content=application_pickle)
@@ -871,6 +893,12 @@ def export(self, with_valid=True):
871893
except Exception as e:
872894
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)
873895

896+
@staticmethod
897+
def to_tool_dict(tool, tool_workflow_dict):
898+
if tool.tool_type == ToolType.WORKFLOW:
899+
return {**ToolExportModelSerializer(tool).data, 'work_flow': tool_workflow_dict.get(tool.id).work_flow}
900+
return ToolExportModelSerializer(tool).data
901+
874902
@staticmethod
875903
def reset_application_version(application_version, application):
876904
update_field_dict = {
@@ -1329,7 +1357,7 @@ def batch_delete(self, instance: Dict, with_valid=True):
13291357
from trigger.serializers.trigger import TriggerModelSerializer
13301358

13311359
if with_valid:
1332-
BatchSerializer(data=instance).is_valid(model=Application,raise_exception=True)
1360+
BatchSerializer(data=instance).is_valid(model=Application, raise_exception=True)
13331361
self.is_valid(raise_exception=True)
13341362
id_list = instance.get("id_list")
13351363
workspace_id = self.data.get('workspace_id')

apps/knowledge/serializers/knowledge_workflow.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from system_manage.models import AuthTargetType
3939
from system_manage.models.resource_mapping import ResourceType
4040
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
41-
from tools.models import Tool, ToolScope
41+
from tools.models import Tool, ToolScope, ToolType, ToolWorkflow
4242
from tools.serializers.tool import ToolExportModelSerializer
4343
from users.models import User
4444

@@ -64,6 +64,10 @@ def hand_node(node, update_tool_map):
6464
mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
6565
node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
6666
mcp_tool_id)
67+
if node.get('type') == 'tool-workflow-lib-node':
68+
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
69+
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
70+
tool_lib_id)
6771

6872

6973
class KnowledgeWorkflowModelSerializer(serializers.ModelSerializer):
@@ -356,6 +360,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True):
356360
if is_import_tool:
357361
if len(tool_model_list) > 0:
358362
QuerySet(Tool).bulk_create(tool_model_list)
363+
QuerySet(ToolWorkflow).bulk_create(
364+
[ToolWorkflow(workspace_id=workspace_id,
365+
work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map),
366+
tool_id=tool.get('id'))
367+
for
368+
tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW])
359369
UserResourcePermissionSerializer(data={
360370
'workspace_id': self.data.get('workspace_id'),
361371
'user_id': self.data.get('user_id'),
@@ -374,6 +384,15 @@ def to_knowledge_workflow(knowledge_workflow, update_tool_map):
374384
hand_node(n, update_tool_map)
375385
return work_flow
376386

387+
@staticmethod
388+
def reset_workflow(work_flow, update_tool_map):
389+
for node in work_flow.get('nodes', []):
390+
hand_node(node, update_tool_map)
391+
if node.get('type') == 'loop-node':
392+
for n in node.get('properties', {}).get('node_data', {}).get('loop_body', {}).get('nodes', []):
393+
hand_node(n, update_tool_map)
394+
return work_flow
395+
377396
@staticmethod
378397
def to_tool(tool, workspace_id, user_id):
379398
return Tool(id=tool.get('id'),
@@ -402,17 +421,20 @@ def export(self, with_valid=True):
402421
knowledge_workflow = QuerySet(KnowledgeWorkflow).filter(knowledge_id=knowledge_id).first()
403422
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
404423
from application.flow.tools import get_tool_id_list
405-
tool_id_list = get_tool_id_list(knowledge_workflow.work_flow)
424+
tool_id_list = get_tool_id_list(knowledge_workflow.work_flow, True)
406425
tool_list = []
407426
if len(tool_id_list) > 0:
408427
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
428+
tw_dict = {tw.tool_id: tw
429+
for tw in QuerySet(ToolWorkflow).filter(
430+
tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])}
409431
knowledge_workflow_dict = KnowledgeWorkflowModelSerializer(knowledge_workflow).data
410432

411433
kbwf_instance = KBWFInstance(
412434
knowledge_workflow_dict,
413435
[],
414436
'v2',
415-
[ToolExportModelSerializer(tool).data for tool in tool_list]
437+
[self.to_tool_dict(tool, tw_dict) for tool in tool_list]
416438
)
417439
knowledge_workflow_pickle = pickle.dumps(kbwf_instance)
418440
response = HttpResponse(content_type='text/plain', content=knowledge_workflow_pickle)
@@ -421,6 +443,12 @@ def export(self, with_valid=True):
421443
except Exception as e:
422444
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)
423445

446+
@staticmethod
447+
def to_tool_dict(tool, tool_workflow_dict):
448+
if tool.tool_type == ToolType.WORKFLOW:
449+
return {**ToolExportModelSerializer(tool).data, 'work_flow': tool_workflow_dict.get(tool.id).work_flow}
450+
return ToolExportModelSerializer(tool).data
451+
424452
class Operate(serializers.Serializer):
425453
user_id = serializers.UUIDField(required=True, label=_('user id'))
426454
workspace_id = serializers.CharField(required=True, label=_('workspace id'))

apps/tools/serializers/tool.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ def one(self):
659659

660660
def get_child_tool_list(self, work_flow, response):
661661
from application.flow.tools import get_tool_id_list
662-
tool_id_list = get_tool_id_list(work_flow)
662+
tool_id_list = get_tool_id_list(work_flow, False)
663663
tool_id_list = [tool_id for tool_id in tool_id_list if
664664
len([r for r in response if r.get('id') == tool_id]) == 0]
665665
tool_list = []
@@ -1307,6 +1307,7 @@ def process():
13071307

13081308
return to_stream_response_simple(process())
13091309

1310+
13101311
class ToolBatchOperateSerializer(serializers.Serializer):
13111312
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
13121313

@@ -1365,8 +1366,6 @@ def batch_move(self, instance: Dict, with_valid=True):
13651366
return True
13661367

13671368

1368-
1369-
13701369
class ToolTreeSerializer(serializers.Serializer):
13711370
class Query(serializers.Serializer):
13721371
workspace_id = serializers.CharField(required=True, label=_('workspace id'))

0 commit comments

Comments
 (0)