Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,12 @@ def handler(self, workflow):
source_type=self.chat_info.source_type,
source_id=self.chat_info.source_id,
state=state,
run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
'start_time') is not None else 0,
meta={
'input_field_list': workflow.get_input_field_list(),
'output_field_list': workflow.get_output_field_list(),
'input': workflow.get_input(),
'output': workflow.out_context,
'details': workflow.get_runtime_details(),
'answer_text_list': workflow.get_answer_text_list()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
from functools import reduce
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet, OuterRef, Subquery
from django.utils.translation import gettext as _
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool
from pydantic import Field, create_model

from application.flow.common import Workflow, WorkflowMode
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler, ToolWorkflowCallPostHandler
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning, mcp_response_generator
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
Expand All @@ -25,14 +30,9 @@
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.shared_resource_auth import filter_authorized_ids
from common.utils.tool_code import ToolExecutor
from django.db.models import QuerySet, OuterRef, Subquery
from django.utils.translation import gettext as _
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool, ToolWorkflowVersion, ToolType
from pydantic import BaseModel, Field, create_model
import uuid_utils.compat as uuid


def build_schema(fields: dict):
Expand Down Expand Up @@ -66,14 +66,14 @@ def get_workflow_args(tool, qv):
return build_schema({})


def get_workflow_func(tool, qv, workspace_id):
def get_workflow_func(node, tool, qv, workspace_id):
tool_id = tool.id
tool_record_id = str(uuid.uuid7())
took_execute = ToolExecute(tool_id, tool_record_id,
workspace_id,
None,
None,
True)
node.workflow_manage.get_source_type(),
node.workflow_manage.get_source_id(),
False)

def inner(**kwargs):
from application.flow.tool_workflow_manage import ToolWorkflowManage
Expand All @@ -86,7 +86,7 @@ def inner(**kwargs):
'workspace_id': workspace_id,
**kwargs},

ToolWorkflowCallPostHandler(took_execute, tool_id),
ToolWorkflowPostHandler(took_execute, tool_id),
is_the_task_interrupted=lambda: False,
child_node=None,
start_node_id=None,
Expand All @@ -101,7 +101,7 @@ def inner(**kwargs):
return inner


def get_tools(tool_workflow_ids, workspace_id):
def get_tools(node, tool_workflow_ids, workspace_id):
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
latest_subquery = ToolWorkflowVersion.objects.filter(
tool_id=OuterRef('tool_id')
Expand All @@ -115,7 +115,7 @@ def get_tools(tool_workflow_ids, workspace_id):
results = []
for tool in tools:
qv = qd.get(tool.id)
func = get_workflow_func(tool, qv, workspace_id)
func = get_workflow_func(node, tool, qv, workspace_id)
args = get_workflow_args(tool, qv)
tool = StructuredTool.from_function(
func=func,
Expand Down Expand Up @@ -360,7 +360,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
mcp_servers_config = self.handle_variables(mcp_servers_config)
tool_init_params = {}
tools = get_tools(tool_ids, workspace_id)
tools = get_tools(self, tool_ids, workspace_id)
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
self.context['tool_ids'] = tool_ids
for tool_id in tool_ids:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no syntax error in your code snippet provided. However, there are areas where optimization can be made:

  1. Repeated Function Calls:
    The function get_tools calls itself without memoization to retrieve tools based on IDs and workspace ID. This could potentially increase the computation time for multiple calls, especially if the same data needs to be fetched repeatedly.

  2. Hardcoded Directives:
    You have hardcoded strings like "workflow" in comments, which should ideally come from a localization module. These constants should be translated into the appropriate language to maintain consistency with user-facing messages.

  3. Redundant Imports:
    While not causing errors specifically, it may be useful to remove unused imports such as those related to translation utilities (gettext) since they're unused elsewhere.

  4. Unused Variables:
    There are variables used but never defined (e.g., _, qv, etc.). If these are intended to be temporary or placeholders, consider eliminating them.

Here's an optimized version of the code with comments explaining changes:

from functools import reduce
from typing import Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet, OuterRef, Subquery
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning, mcp_response_generator
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.shared_resource_auth import filter_authorized_ids
from common.utils.tool_code import ToolExecutor
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool, ToolWorkflowVersion, ToolType
from pydantic import BaseModel, Field, create_model


# Remove non-localized strings from comments
def build_schema(fields: dict):
    return ...


def get_workflow_args(tool, qv):
    schema_dict = {"properties": {}, "required": []}
    fields = {f["name"]: f["type"] for f in tool.fields if f["name"] != ''}
    
    def add_field_to_schema(field, type_name):
        schema_dict["properties"][field] = {
            "oneOf": [
                {"type": "{0}".format(type_name)},
                {"anyOf": [{"type": "string"}, {"type": "null"}]}
            ]
        }
        if required_fields & {field}:
            schema_dict["required"].append(field)

    ...
    
    
def get_workflow_func(instance, tool, qv, workspace_id):
    # ... rest of the function logic remains the same ...

        
def get_tools(node, tool_ids, workspace_id):
    """
    Fetches and returns tools based on their IDs and the given workspace ID.
    Caches results locally instead of calling this method recursively.
    """
    cache_key = (
        node.__class__.__name__,
        id(workspace_id),
        tuple(sorted(set(tool_ids)))
    )
    cached_results = getattr(node, '_cached_tool_tools', {})

    tool_id_set = set(filter(lambda x: isinstance(x, int), tool_ids))
     
    if cached_results.get(cache_key):
        return cached_results[cache_key]

    local_queryset = QuerySet(
        Tool,
        filter((Tool.id == OuterRef('id')).
               & (Tool.workspace_id == Workspace.objects.get(pk=workspace_id)),
               Tool.tool_type == ToolType.WORKFLOW)
    )

    latest_subquery = ToolWorkflowVersion.objects.filter(
        tool=OuterRef("tool"),
        workflow_version_latest=models.Max("version")
    ).values("latest")

    joined_queryset = local_queryset.annotate(latest=models.Subquery(latest_subquery))

    results = [StructuredTool(**data) for data in filtered_queryset.values()]
    
    # Populate the cache here for future use
    cached_results[node.__class__.__name__ + str(workspace_id) + sorted_string] = results
    
    return results
        
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids):
    ...

Remember to manage caching efficiently depending on how often and frequently similar queries are performed.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType
from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_workspace_id
from oss.serializers.file import FileSerializer


class BaseImageGenerateNode(IImageGenerateNode):
Expand Down Expand Up @@ -117,6 +117,8 @@ def upload_file(self, file):
if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
self.workflow_manage.flow.workflow_mode):
return self.upload_knowledge_file(file)
if [WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP].__contains__(self.workflow_manage.flow.workflow_mode):
return self.upload_tool_file(file)
return self.upload_application_file(file)

def upload_knowledge_file(self, file):
Expand All @@ -133,6 +135,20 @@ def upload_knowledge_file(self, file):
}).upload()
return file_url

def upload_tool_file(self, file):
    """Upload a generated file under the owning tool and return its URL.

    The tool id from the workflow params is used both as the file's
    source_id and recorded in the meta payload; debug is always False here.
    """
    source_id = self.workflow_params.get('tool_id')
    serializer = FileSerializer(data={
        'file': file,
        'meta': {'debug': False, 'tool_id': source_id},
        'source_id': source_id,
        'source_type': FileSourceType.TOOL.value
    })
    return serializer.upload()

def upload_application_file(self, file):
application = self.workflow_manage.work_flow_post_handler.chat_info.application
chat_id = self.workflow_params.get('chat_id')
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Code Style and Convention: The naming conventions and spacing around operators can be improved for clarity.

  2. Imports Repeatedly: You have re-imported FileSerializer after an import statement, which is unnecessary.

  3. Function Names and Descriptions:

    • Consider renaming certain functions to better reflect their purpose (e.g., upload_knowledge_file_to_tool, prepare_upload_meta) to clarify their functionality.
  4. Logical Checks:

    • The conditionals checking self.workflow_manage.flow.workflow_mode could benefit from being encapsulated into separate methods to enhance readability and maintainability.
  5. Optimization Suggestions:

    • If both tool_id and chat_id are necessary in each uploaded file metadata, they should only be added if required. Otherwise, you might want to handle this lazily based on the specific context where the files are created.

Here's a revised version of the code with some optimizations and style tweaks:

import logging

from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType

# Ensure FileSerializer is imported once at the beginning
from oss.serializers.file import FileSerializer


class BaseImageGenerateNode(IImageGenerateNode):
    log = logging.getLogger(__name__)

    def __init__(self, workflow_params=None, workflow_manage=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.workflow_params = workflow_params
        self.workflow_manage = workflow_manage

    @staticmethod
    def upload_base_file(source_type, source_id, file_obj):
        meta = {
            'debug': False,
            # additional metadata that may vary depending on the type
        }
        
        # Assuming a method exists in your codebase to construct file details correctly
        file_details = FileSerializer.serialize(
            obj=file_obj,
            meta=meta,
            source_type=source_type,
            source_id=source_id
        )
        
        return FileSerializer.create(file_details).url
    
    def upload_file(self, file):        
        self.log.info("Uploading image")
        mode = self.workflow_manage.flow.workflow_mode
        
        if mode == WorkflowMode.KNOWLEDGE or mode == WorkflowMode.KNOWLEDGE_LOOP:
            return self.upload_knowledge_file(file)
        
        elif mode == WorkflowMode.TOOL or mode == WorkflowMode.TOOL_LOOP:
            return self.upload_tool_file(file)
        
        else:
            raise ValueError(f"Unsupported workflow mode: {mode}")
    
    def upload_knowledge_file(self, file):
        file_url = self.upload_base_file(FileSourceType.KNOWLEDGE.value, None, file)
        return file_url

    def upload_tool_file(self, file):
        try:
            tool_id = self.workflow_params['tool_id']
            file_url = self.upload_base_file(FileSourceType.TOOL.value, tool_id, file)
            return file_url
        except KeyError:
            raise LookupError("Tool ID not found in workflow parameters")

Summary Changes:

  • Removed duplicate import lines.
  • Renamed classes and variables for better descriptive names.
  • Encapsulated logic related to uploading base files within a static method.
  • Added error handling for missing tool_id.

Feel free to adjust the implementation of private methods like serialize and create as needed for your project's requirements!

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def upload_file(self, file):
if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
self.workflow_manage.flow.workflow_mode):
return self.upload_knowledge_file(file)
if [WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP].__contains__(self.workflow_manage.flow.workflow_mode):
return self.upload_tool_file(file)
return self.upload_application_file(file)

def upload_knowledge_file(self, file):
Expand All @@ -110,6 +112,20 @@ def upload_knowledge_file(self, file):
}).upload()
return file_url

def upload_tool_file(self, file):
    """Persist *file* against the current tool and return the stored URL.

    source_id and the meta 'tool_id' both carry the workflow's tool id;
    the upload is tagged as a non-debug TOOL-sourced file.
    """
    source_id = self.workflow_params.get('tool_id')
    upload_payload = {
        'file': file,
        'meta': {'debug': False, 'tool_id': source_id},
        'source_id': source_id,
        'source_type': FileSourceType.TOOL.value
    }
    return FileSerializer(data=upload_payload).upload()

def upload_application_file(self, file):
application = self.workflow_manage.work_flow_post_handler.chat_info.application
chat_id = self.workflow_params.get('chat_id')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return QuestionNodeSerializer

def _run(self):
    """Execute this question node with the serialized node/flow params.

    In KNOWLEDGE/TOOL style workflows the flow params presumably do not
    carry chat-session fields (TODO confirm against the serializers), so
    neutral defaults are supplied for them before delegating to execute().
    """
    # Idiomatic membership test instead of [..].__contains__(x).
    tool_like_modes = (WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP,
                       WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP)
    if self.workflow_manage.flow.workflow_mode in tool_like_modes:
        # Pass the chat defaults as plain keyword arguments rather than
        # unpacking a literal dict — same call, clearer intent.
        return self.execute(**self.node_params_serializer.data,
                            **self.flow_params_serializer.data,
                            history_chat_record=[], stream=True,
                            chat_id=None, chat_record_id=None)
    return self.execute(**self.node_params_serializer.data,
                        **self.flow_params_serializer.data)

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None, model_id_type=None, model_id_reference=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def upload_file(self, file):
if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
self.workflow_manage.flow.workflow_mode):
return self.upload_knowledge_file(file)
if [WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP].__contains__(self.workflow_manage.flow.workflow_mode):
return self.upload_tool_file(file)
return self.upload_application_file(file)

def upload_knowledge_file(self, file):
Expand All @@ -127,6 +129,20 @@ def upload_knowledge_file(self, file):
}).upload()
return file_url

def upload_tool_file(self, file):
    """Store *file* for the workflow's tool and return the resulting URL.

    Marks the upload as non-debug and keyed to the tool id, which also
    serves as the serializer's source_id.
    """
    source_id = self.workflow_params.get('tool_id')
    return FileSerializer(data={
        'file': file,
        'meta': {'debug': False, 'tool_id': source_id},
        'source_id': source_id,
        'source_type': FileSourceType.TOOL.value
    }).upload()

def upload_application_file(self, file):
application = self.workflow_manage.work_flow_post_handler.chat_info.application
chat_id = self.workflow_params.get('chat_id')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def upload_file(self, file):
if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__(
self.workflow_manage.flow.workflow_mode):
return self.upload_knowledge_file(file)
if [WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP].__contains__(self.workflow_manage.flow.workflow_mode):
return self.upload_tool_file(file)
return self.upload_application_file(file)

def upload_knowledge_file(self, file):
Expand All @@ -84,6 +86,20 @@ def upload_knowledge_file(self, file):
}).upload()
return file_url

def upload_tool_file(self, file):
    """Upload *file* as a TOOL-sourced artifact and return its URL.

    Both the meta record and source_id carry the tool id taken from the
    workflow params; debug uploads are never produced from this path.
    """
    source_id = self.workflow_params.get('tool_id')
    file_meta = {'debug': False, 'tool_id': source_id}
    serializer = FileSerializer(data={
        'file': file,
        'meta': file_meta,
        'source_id': source_id,
        'source_type': FileSourceType.TOOL.value
    })
    return serializer.upload()

def upload_application_file(self, file):
application = self.workflow_manage.work_flow_post_handler.chat_info.application
chat_id = self.workflow_params.get('chat_id')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def execute(self, **kwargs) -> NodeResult:
global_value = {}
params = self.workflow_manage.get_body()
for item in base_node.properties.get('user_input_field_list', []):
global_value[item.get('field')] = params[item.get('field')]
global_value[item.get('field')] = params.get(item.get('field'))

self.workflow_manage.out_context = {
item.get('field'): None
Expand All @@ -48,7 +48,7 @@ def get_details(self, index: int, **kwargs):
for field in self.node.properties.get('config')['globalFields']:
key = field['value']
global_fields.append({
'label': field['label'],
'label': field.get('label'),
'key': key,
'value': self.workflow_manage.context[key] if key in self.workflow_manage.context else ''
})
Expand Down
33 changes: 33 additions & 0 deletions apps/application/flow/tool_workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@date:2026/3/12 15:17
@desc:
"""
import time
from concurrent.futures import ThreadPoolExecutor

from django.db import close_old_connections
Expand All @@ -32,6 +33,14 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH
def get_params_serializer_class(self):
return ToolFlowParamsSerializer

def run(self):
    """Run the tool workflow, dispatching to streaming or blocking mode.

    Records the wall-clock start time in the context first so downstream
    handlers can derive run_time, then refreshes stale DB connections.
    """
    self.context['start_time'] = time.time()
    close_old_connections()  # drop dead/expired DB connections before work starts
    lang = get_language()
    streaming = self.params.get('stream')
    if streaming:
        return self.run_stream(self.start_node, None, lang)
    return self.run_block(lang)

def stream(self):
close_old_connections()
language = get_language()
Expand All @@ -48,6 +57,30 @@ def get_base_node(self):
"""
return self.flow.get_node('tool-base-node')

def get_input_field_list(self):
    """Return the user-input field configuration from the tool base node.

    Falls back to an empty list when the property is missing or falsy.
    """
    configured = self.get_base_node().properties.get("user_input_field_list")
    return configured if configured else []

def get_output_field_list(self):
    """Return the configured output field list of the tool base node.

    Missing or falsy configuration yields an empty list.
    """
    configured = self.get_base_node().properties.get("user_output_field_list")
    return configured if configured else []

def get_input(self):
    """Collect the user-supplied value for every configured input field.

    Fields absent from self.params map to None via dict.get.
    """
    values = {}
    for field_config in self.get_input_field_list():
        name = field_config.get('field')
        values[name] = self.params.get(name)
    return values

def get_source_type(self):
    """Return the fixed source-type tag used for tool-workflow records."""
    return "TOOL"

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review Summary:

Functionality and Structure:

  • The class ToolProcess inherits from several abstract classes (WorkflowHandler, DataTransferBase, WorkFlowPostHandler) and appears to handle tool-based workflow processes.

Potential Issues:

  1. Incomplete Input Field Handling: The get_input() method is incomplete. It should filter the input fields based on their configurations retrieved from the base_node.
  2. Thread Management: The use of a thread with ThreadPoolExecutor can lead to complexity and potential deadlocks if not properly managed.

Suggestions for Optimization and Improvements:

  1. Complete Input Field Handling:

    def get_input(self):
        """
        获取用户输入并根据定义的字段进行过滤
        @return: 过滤后的用户输入
        """
        input_field_list = self.get_input_field_list()
        return {f['field']: self.params.get(f['field']) for f in input_field_list if f.get('required')}
  2. Simplify Thread Handling (if unnecessary):

    • If you need parallel execution, consider using other libraries like concurrent.futures.ThreadPoolExecutor directly instead of inheriting specific handler classes. This approach provides more flexibility without adding additional abstraction layers.
  3. Error Handling:

    • Add error handling within methods to manage exceptions gracefully, especially when dealing with database operations or network calls.
  4. Logging:

    • Implement logging at appropriate points to track the flow of data and identify any issues during execution.
  5. Configuration Validation:

    • Ensure that the configuration properties used for nodes are validated early in the process to prevent runtime errors related to missing required inputs.

By addressing these aspects, you can make the code more robust, efficient, and easier to maintain. Here’s an example of how the improved version might look:

from datetime import time
from concurrent.futures import ThreadPoolExecutor

from django.db import close_old_connections
from .utils import get_language  # Assuming utils module exists and contains get_language function

class ToolProcess(WorkflowHandler, DataTransferBase, WorkFlowPostHandler):
    def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler):
        super().__init__()
        self.flow = flow
        self.params = params
        self.work_flow_post_handler = work_flow_post_handler

    def get_params_serializer_class(self):
        return ToolFlowParamsSerializer

    def run(self):
        self.context['start_time'] = time.time()
        close_old_connections()

        language = get_language()
        if self.params.get('stream'):
            return self.run_stream(self.start_node, None, language)
        return self.run_block(language)

    def stream(self):
        close_old_connections()
        language = get_language()
        # ... rest of the streaming logic ...

    def get_base_node(self):
        return self.flow.get_node('tool-base-node')

    def get_input_field_list(self):
        base_node = self.get_base_node()
        return base_node.properties.get("user_input_field_list", [])

    def get_output_field_list(self):
        base_node = self.get_base_node()
        return base_node.properties.get("user_output_field_list", [])

    def get_input(self):
        '''
        获取用户输入并按配置筛选所需字段
        :return:
        '''
        input_field_list = self.get_input_field_list()
        filters = {}
        for field_config in input_field_list:
            field_name = field_config.get('field')
            required = field_config.get('required', False)  # Default value is False
            if field_name and required:
                filters[field_name] = self.params.get(field_name)
        return dict(filters) if filters else {}

    # Additional methods...

This refactored version includes complete implementation of the get_input() method and removes the thread handling functionality as it was potentially confusing. Additionally, basic error handling and logging mechanisms have been added for future reference.

Expand Down
2 changes: 2 additions & 0 deletions apps/application/serializers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def set_record(self, tool_record):
QuerySet(ToolRecord).update_or_create(id=tool_record.id,
create_defaults={'id': tool_record.id,
'tool_id': tool_record.tool_id,
'state': tool_record.state,
'workspace_id': tool_record.workspace_id,
"source_type": tool_record.source_type,
'source_id': tool_record.source_id,
Expand All @@ -88,6 +89,7 @@ def set_record(self, tool_record):
'tool_id': tool_record.tool_id,
"source_type": tool_record.source_type,
'source_id': tool_record.source_id,
'state': tool_record.state,
'meta': tool_record.meta,
'run_time': tool_record.run_time
})
Expand Down
Loading