Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions apps/application/flow/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@date:2024/6/6 15:15
@desc:
"""
from tools.models import ToolRecord, Tool
from tools.models import ToolRecord, Tool, ToolScope
from maxkb.const import CONFIG
from knowledge.models.knowledge_action import State
from knowledge.models import File
Expand Down Expand Up @@ -891,7 +891,8 @@ def save_workflow_mapping(workflow, source_type, source_id, other_resource_mappi
{(str(item.target_type) + str(item.target_id)): item for item in resource_mapping_list}.values())


def get_tool_id_list(workflow):
def get_tool_id_list(workflow, with_deep=False):
from tools.models import ToolWorkflow, ToolType
_result = []
for node in workflow.get('nodes', []):
if node.get('type') == 'tool-lib-node':
Expand Down Expand Up @@ -921,4 +922,34 @@ def get_tool_id_list(workflow):
'node_data', {}).get('mcp_tool_id')
if mcp_tool_id:
_result.append(mcp_tool_id)
if with_deep:
workflow_list = QuerySet(Tool).filter(id__in=_result, tool_type=ToolType.WORKFLOW)
tool_work_flow_list = QuerySet(ToolWorkflow).filter(tool_id__in=[wl.id for wl in workflow_list])
for tool_work_flow in tool_work_flow_list:
child_tool_id_list = get_child_tool_id_list(tool_work_flow.work_flow, [])
for c in child_tool_id_list:
_result.append(c)
return _result


def get_child_tool_id_list(work_flow, response):
    """Recursively collect the ids of all non-shared tools referenced by *work_flow*.

    :param work_flow: workflow dict whose nodes are scanned via get_tool_id_list.
    :param response: accumulator list of tool-id strings; also used as the
        visited set so cyclic tool references cannot recurse forever.
    :return: *response*, with every newly discovered tool id appended.
    """
    from tools.models import ToolWorkflow, ToolType
    # Skip ids we have already visited; a set membership test avoids the
    # O(n^2) scan of rebuilding a filtered list per candidate.
    seen = set(response)
    tool_id_list = [tool_id for tool_id in get_tool_id_list(work_flow, False)
                    if tool_id not in seen]
    if len(tool_id_list) == 0:
        return response
    tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
    work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW]
    work_flow_tool_dict = {}
    if len(work_flow_tools) > 0:
        work_flow_tool_dict = {tw.tool_id: tw for tw in
                               QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])}
    # Single pass: record every tool id, and descend into workflow-type tools.
    for tool in tool_list:
        response.append(str(tool.id))
        if tool.tool_type == ToolType.WORKFLOW:
            tool_workflow = work_flow_tool_dict.get(tool.id)
            # Guard against a missing ToolWorkflow row; the original code
            # raised AttributeError on .work_flow in that case.
            if tool_workflow is not None:
                get_child_tool_id_list(tool_workflow.work_flow, response)
    return response
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Imports and Variable Naming:

    • The variable name _result in get_tool_id_list() is misleading because it suggests that this variable will store results, but it ends up being an empty list.
    • In save_workflow_mapping(), the comment about filtering out files using a specific key ((str(item.target_type) + str(item.target_id))) may be redundant since the actual filter logic does not include such functionality.
  2. Logic of get_tool_id_list():

    • There's no clear separation between different types of nodes (like tools and non-tools) when collecting IDs from workflows.
  3. Improvements for Recursive Handling:

    • Using recursion might make sense for exploring all children in the workflow tree structure to collect tool ID lists. Consider renaming variables like response to reflect their intended purpose or use more descriptive names.
  4. Edge Cases in Functions:

    • Implement checks for edge cases where there are no matching records or unexpected data structures.

Here’s an optimized version incorporating some suggested changes:

def get_tool_id_list(workflow, with_deep=False):
    _result = set()
    
    if 'nodes' in workflow:
        for node in workflow['nodes']:
            if node.get('type') == 'tool-lib-node':
                mcp_tool_id = node.get('node_data', {}).get('mcp_tool_id')
                if mcp_tool_id:
                    _result.add(mcp_tool_id)
    
    if with_deep:
        # Assuming workflow_map includes mappings for nested tools
        from maxkb.map.model.workflowMap import WorkflowMapping
        
        # This loop assumes every record maps to multiple related ids based on target_key_str format
        for mapping in map.records():
            for record in mapping.related_records_with_target_type(str(node.target_type)):
                _result.update(map.target_ids_of_record(record, str(mapping.target_key_str)))
        
        # Recursively fetch additional workflow tools
        for id in _result:
            wf_tools = ToolWorkflowService.fetch_related_workflows_for_tool(id)
            for wf_tool in wf_tools:
                child_tool_id_list = get_tool_id_list(wf_tool.work_flow, True)
                _result.update(child_tool_id_list)
                
    return sorted(_result)

# Similar improvements can be applied similarly to `get_child_tool_id_list`

This version uses sets to avoid duplicate entries, which improves efficiency especially when dealing with large datasets. It also outlines general guidelines that could guide further optimizations depending on actual usage patterns and performance needs.

36 changes: 32 additions & 4 deletions apps/application/serializers/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from system_manage.models.resource_mapping import ResourceMapping
from system_manage.serializers.resource_mapping_serializers import ResourceMappingSerializer
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
from tools.models import Tool, ToolScope, ToolType
from tools.models import Tool, ToolScope, ToolType, ToolWorkflow
from tools.serializers.tool import ToolExportModelSerializer
from trigger.models import TriggerTask, Trigger
from users.models import User
Expand Down Expand Up @@ -93,6 +93,10 @@ def hand_node(node, update_tool_map):
mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
mcp_tool_id)
if node.get('type') == 'tool-workflow-lib-node':
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
tool_lib_id)


class MKInstance:
Expand Down Expand Up @@ -628,6 +632,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True):
if is_import_tool:
if len(tool_model_list) > 0:
QuerySet(Tool).bulk_create(tool_model_list)
QuerySet(ToolWorkflow).bulk_create(
[ToolWorkflow(workspace_id=workspace_id,
work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map),
tool_id=tool.get('id'))
for
tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW])
UserResourcePermissionSerializer(data={
'workspace_id': self.data.get('workspace_id'),
'user_id': self.data.get('user_id'),
Expand Down Expand Up @@ -670,6 +680,15 @@ def to_tool(tool, workspace_id, user_id):
folder_id=workspace_id,
workspace_id=workspace_id)

@staticmethod
def reset_workflow(work_flow, update_tool_map):
    """Remap tool ids on every node of *work_flow* (including nodes nested
    inside loop-node bodies) via hand_node, then return the same dict.

    :param work_flow: workflow dict containing a 'nodes' list.
    :param update_tool_map: mapping of old tool id -> new tool id.
    :return: the mutated *work_flow* dict.
    """
    for outer_node in work_flow.get('nodes', []):
        hand_node(outer_node, update_tool_map)
        if outer_node.get('type') != 'loop-node':
            continue
        loop_body = outer_node.get('properties', {}).get('node_data', {}).get('loop_body', {})
        for inner_node in loop_body.get('nodes', []):
            hand_node(inner_node, update_tool_map)
    return work_flow

@staticmethod
def to_application(application, workspace_id, user_id, update_tool_map, folder_id):
work_flow = application.get('work_flow')
Expand Down Expand Up @@ -844,13 +863,16 @@ def export(self, with_valid=True):
application_id = self.data.get('application_id')
application = QuerySet(Application).filter(id=application_id).first()
from application.flow.tools import get_tool_id_list
tool_id_list = get_tool_id_list(application.work_flow)
tool_id_list = get_tool_id_list(application.work_flow, True)
if len(tool_id_list) > 0:
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
else:
tool_list = QuerySet(Tool).filter(
id__in=application.tool_ids + application.mcp_tool_ids + application.skill_tool_ids
).exclude(scope=ToolScope.SHARED)
tw_dict = {tw.tool_id: tw
for tw in QuerySet(ToolWorkflow).filter(
tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])}
# 如果是技能工具,则需要将code字段转换为文件内容的base64字符串
for tool in tool_list:
if tool.tool_type == ToolType.SKILL:
Expand All @@ -862,7 +884,7 @@ def export(self, with_valid=True):
mk_instance = MKInstance(application_dict,
[],
'v2',
[ToolExportModelSerializer(tool).data for tool in
[self.to_tool_dict(tool, tw_dict) for tool in
tool_list])
application_pickle = pickle.dumps(mk_instance)
response = HttpResponse(content_type='text/plain', content=application_pickle)
Expand All @@ -871,6 +893,12 @@ def export(self, with_valid=True):
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)

@staticmethod
def to_tool_dict(tool, tool_workflow_dict):
    """Serialize *tool* for export, attaching the inner workflow definition
    when the tool is a workflow-type tool.

    :param tool: Tool model instance to serialize.
    :param tool_workflow_dict: mapping of tool id -> ToolWorkflow record.
    :return: plain dict suitable for pickling into the export payload.
    """
    tool_data = ToolExportModelSerializer(tool).data
    if tool.tool_type == ToolType.WORKFLOW:
        tool_workflow = tool_workflow_dict.get(tool.id)
        # Guard against a missing ToolWorkflow row: the export must not crash
        # with AttributeError on .work_flow when the mapping record is absent.
        if tool_workflow is not None:
            return {**tool_data, 'work_flow': tool_workflow.work_flow}
    return tool_data

@staticmethod
def reset_application_version(application_version, application):
update_field_dict = {
Expand Down Expand Up @@ -1329,7 +1357,7 @@ def batch_delete(self, instance: Dict, with_valid=True):
from trigger.serializers.trigger import TriggerModelSerializer

if with_valid:
BatchSerializer(data=instance).is_valid(model=Application,raise_exception=True)
BatchSerializer(data=instance).is_valid(model=Application, raise_exception=True)
self.is_valid(raise_exception=True)
id_list = instance.get("id_list")
workspace_id = self.data.get('workspace_id')
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code looks mostly correct but contains a few areas that could be improved:

  1. Line Length: The line QuerySet(ToolWorkflow).bulk_create([...]) exceeds the maximum recommended line length of 80 characters, which can make it harder to read.

  2. Tool Workflow Reset Method: The reset_workflow method should use the provided update_tool_map on all nodes of both main flow and loop bodies. However, there's a small mistake where the condition to check for a workflow type (node.get('type') == 'workflow') might need to include 'tool-workflow-lib-node'.

  3. Export Functionality: In the export function, the logic for handling skill tools should be more flexible. Currently, it assumes these tool objects have a file_content attribute that needs conversion into base64 data. Consider checking if file_content exists before attempting the conversion.

Here are some specific improvements to suggest:

Line Length Improvement

QuerySet(ToolWorkflow).bulk_create([ToolWorkflow(workspace_id=ws_id, work_flow=self.reset_workflow(twf, update_tool_map), tool_id=tw_id)
                                    for tw in tool_list if tw.tool_type == ToolType.WORKFLOW])

Correction in Workflow Type Check Inside reset_workflow

if node.get('type') in ['workflow', 'tool-workflow-lib-node']:
    ...

Enhanced Export Handling for Skill Tools

Ensure skill tools have a file content attribute before converting it to base64.

for tool in tool_list:
    if tool.tool_type == ToolType.SKILL:
        if not hasattr(tool, 'file_content'):
            # Additional error logging or default value setting
            pass
        # Convert file content to base64 string here...
        ...

def convert_file_to_base64(file_content):
    return base64.b64encode(file_content.read()).decode('utf-8')

These changes aim to improve readability while maintaining functionality. If you want further optimizations or additional considerations, let me know!

Expand Down
34 changes: 31 additions & 3 deletions apps/knowledge/serializers/knowledge_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from system_manage.models import AuthTargetType
from system_manage.models.resource_mapping import ResourceType
from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
from tools.models import Tool, ToolScope
from tools.models import Tool, ToolScope, ToolType, ToolWorkflow
from tools.serializers.tool import ToolExportModelSerializer
from users.models import User

Expand All @@ -64,6 +64,10 @@ def hand_node(node, update_tool_map):
mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
mcp_tool_id)
if node.get('type') == 'tool-workflow-lib-node':
tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
tool_lib_id)


class KnowledgeWorkflowModelSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -356,6 +360,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True):
if is_import_tool:
if len(tool_model_list) > 0:
QuerySet(Tool).bulk_create(tool_model_list)
QuerySet(ToolWorkflow).bulk_create(
[ToolWorkflow(workspace_id=workspace_id,
work_flow=self.reset_workflow(tool.get('work_flow'), update_tool_map),
tool_id=tool.get('id'))
for
tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW])
UserResourcePermissionSerializer(data={
'workspace_id': self.data.get('workspace_id'),
'user_id': self.data.get('user_id'),
Expand All @@ -374,6 +384,15 @@ def to_knowledge_workflow(knowledge_workflow, update_tool_map):
hand_node(n, update_tool_map)
return work_flow

@staticmethod
def reset_workflow(work_flow, update_tool_map):
    """Remap tool ids on every node of *work_flow* (including nodes nested
    inside loop-node bodies) via hand_node, then return the same dict.

    :param work_flow: workflow dict containing a 'nodes' list.
    :param update_tool_map: mapping of old tool id -> new tool id.
    :return: the mutated *work_flow* dict.
    """
    for outer_node in work_flow.get('nodes', []):
        hand_node(outer_node, update_tool_map)
        if outer_node.get('type') != 'loop-node':
            continue
        loop_body = outer_node.get('properties', {}).get('node_data', {}).get('loop_body', {})
        for inner_node in loop_body.get('nodes', []):
            hand_node(inner_node, update_tool_map)
    return work_flow

@staticmethod
def to_tool(tool, workspace_id, user_id):
return Tool(id=tool.get('id'),
Expand Down Expand Up @@ -402,17 +421,20 @@ def export(self, with_valid=True):
knowledge_workflow = QuerySet(KnowledgeWorkflow).filter(knowledge_id=knowledge_id).first()
knowledge = QuerySet(Knowledge).filter(id=knowledge_id).first()
from application.flow.tools import get_tool_id_list
tool_id_list = get_tool_id_list(knowledge_workflow.work_flow)
tool_id_list = get_tool_id_list(knowledge_workflow.work_flow, True)
tool_list = []
if len(tool_id_list) > 0:
tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
tw_dict = {tw.tool_id: tw
for tw in QuerySet(ToolWorkflow).filter(
tool_id__in=[tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW])}
knowledge_workflow_dict = KnowledgeWorkflowModelSerializer(knowledge_workflow).data

kbwf_instance = KBWFInstance(
knowledge_workflow_dict,
[],
'v2',
[ToolExportModelSerializer(tool).data for tool in tool_list]
[self.to_tool_dict(tool, tw_dict) for tool in tool_list]
)
knowledge_workflow_pickle = pickle.dumps(kbwf_instance)
response = HttpResponse(content_type='text/plain', content=knowledge_workflow_pickle)
Expand All @@ -421,6 +443,12 @@ def export(self, with_valid=True):
except Exception as e:
return result.error(str(e), response_status=status.HTTP_500_INTERNAL_SERVER_ERROR)

@staticmethod
def to_tool_dict(tool, tool_workflow_dict):
    """Serialize *tool* for export, attaching the inner workflow definition
    when the tool is a workflow-type tool.

    :param tool: Tool model instance to serialize.
    :param tool_workflow_dict: mapping of tool id -> ToolWorkflow record.
    :return: plain dict suitable for pickling into the export payload.
    """
    tool_data = ToolExportModelSerializer(tool).data
    if tool.tool_type == ToolType.WORKFLOW:
        tool_workflow = tool_workflow_dict.get(tool.id)
        # Guard against a missing ToolWorkflow row: the export must not crash
        # with AttributeError on .work_flow when the mapping record is absent.
        if tool_workflow is not None:
            return {**tool_data, 'work_flow': tool_workflow.work_flow}
    return tool_data

class Operate(serializers.Serializer):
user_id = serializers.UUIDField(required=True, label=_('user id'))
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your code appears to be generally well-structured, but there are a few areas that could benefit from improvements:

Improvements:

  1. Static Methods for Workflow Resetting:

    • The reset_workflow method can be made static as it doesn't need access to an instance of KnowledgeWorkflowModelSerializer. This improves readability and makes it reusable.
  2. Tool Dictionary Creation:

    • In the _to_tool_dict method, ensure that the dictionary is correctly formatted regardless of whether the tool type is workflow or not. You might want to separate handling based on the tool type.
  3. Error Handling in Export Method:

    • Consider logging exceptions or adding more informative error messages in the export method to help debug issues during exports.
  4. Consistent Use of Query Sets:

    • Ensure that all queries use appropriate query sets for better performance and security. For example, consider using .select_related() and .prefetch_related() where applicable.
  5. Validation in Importer:

    • Before creating objects, consider validating data integrity or constraints (e.g., unique keys).
  6. Code Comments:

    • Add comments for complex parts of the code to enhance readability for others who may maintain it.

Optimizations:

  • Bulk Operations:

    • Ensure that bulk operations like QuerySet.objects.bulk_create() are used properly. Make sure you have indexed fields where necessary to optimize database performance.
  • Logging:

    • Implement logging around critical sections of the code to trace its flow and identify any issues without causing service unavailability.

Here's your updated code snippet with some improvements applied:

@@ -38,7 +38,7 @@
 from system_manage.models import AuthTargetType
 from system_manage.models.resource_mapping import ResourceType
 from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
-from tools.models import Tool, ToolScope
+from tools.models import Tool, ToolScope, ToolType, ToolWorkflow
 from tools.serializers.tool import ToolExportModelSerializer
 from users.models import User

@@ -64,6 +64,10 @@ def hand_node(node, update_tool_map):
         mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
         node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
                                                                                              mcp_tool_id)
+    if node.get('type') == 'tool-workflow-lib-node':
+        tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
+        node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id,
+                                                                                             tool_lib_id)
 
 
 class KnowledgeWorkflowModelSerializer(serializers.ModelSerializer):
@@ -356,6 +360,12 @@ def import_(self, instance: dict, is_import_tool, with_valid=True):
             if is_import_tool:
                 if len(tool_model_list) > 0:
                     QuerySet(Tool).bulk_create(tool_model_list)
+                    QuerySet(ToolWorkflow).bulk_create(
+                        [ToolWorkflow(workspace_id=workspace_id,
+                                      work_flow=self._reset_workflow(tool.get('work_flow'), update_tool_map),
+                                      tool_id=tool.get('id'))
+                         for tool in tool_list if tool.get('tool_type') == ToolType.WORKFLOW])
+

These changes should improve the robustness and maintainability of your codebase.

Expand Down
5 changes: 2 additions & 3 deletions apps/tools/serializers/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def one(self):

def get_child_tool_list(self, work_flow, response):
from application.flow.tools import get_tool_id_list
tool_id_list = get_tool_id_list(work_flow)
tool_id_list = get_tool_id_list(work_flow, False)
tool_id_list = [tool_id for tool_id in tool_id_list if
len([r for r in response if r.get('id') == tool_id]) == 0]
tool_list = []
Expand Down Expand Up @@ -1307,6 +1307,7 @@ def process():

return to_stream_response_simple(process())


class ToolBatchOperateSerializer(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))

Expand Down Expand Up @@ -1365,8 +1366,6 @@ def batch_move(self, instance: Dict, with_valid=True):
return True




class ToolTreeSerializer(serializers.Serializer):
class Query(serializers.Serializer):
workspace_id = serializers.CharField(required=True, label=_('workspace id'))
Expand Down
Loading