-
Notifications
You must be signed in to change notification settings - Fork 2.8k
fix: tool workflow bugs #4962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: tool workflow bugs #4962
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,8 +10,8 @@ | |
| from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode | ||
| from common.utils.common import bytes_to_uploaded_file | ||
| from knowledge.models import FileSourceType | ||
| from oss.serializers.file import FileSerializer | ||
| from models_provider.tools import get_model_instance_by_model_workspace_id | ||
| from oss.serializers.file import FileSerializer | ||
|
|
||
|
|
||
| class BaseImageGenerateNode(IImageGenerateNode): | ||
|
|
@@ -117,6 +117,8 @@ def upload_file(self, file): | |
| if [WorkflowMode.KNOWLEDGE, WorkflowMode.KNOWLEDGE_LOOP].__contains__( | ||
| self.workflow_manage.flow.workflow_mode): | ||
| return self.upload_knowledge_file(file) | ||
| if [WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP].__contains__(self.workflow_manage.flow.workflow_mode): | ||
| return self.upload_tool_file(file) | ||
| return self.upload_application_file(file) | ||
|
|
||
| def upload_knowledge_file(self, file): | ||
|
|
@@ -133,6 +135,20 @@ def upload_knowledge_file(self, file): | |
| }).upload() | ||
| return file_url | ||
|
|
||
| def upload_tool_file(self, file): | ||
| tool_id = self.workflow_params.get('tool_id') | ||
| meta = { | ||
| 'debug': False, | ||
| 'tool_id': tool_id, | ||
| } | ||
| file_url = FileSerializer(data={ | ||
| 'file': file, | ||
| 'meta': meta, | ||
| 'source_id': tool_id, | ||
| 'source_type': FileSourceType.TOOL.value | ||
| }).upload() | ||
| return file_url | ||
|
|
||
| def upload_application_file(self, file): | ||
| application = self.workflow_manage.work_flow_post_handler.chat_info.application | ||
| chat_id = self.workflow_params.get('chat_id') | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Here's a revised version of the code with some optimizations and style tweaks: import logging
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file
from knowledge.models import FileSourceType
# Ensure FileSerializer is imported once at the beginning
from oss.serializers.file import FileSerializer
class BaseImageGenerateNode(IImageGenerateNode):
log = logging.getLogger(__name__)
def __init__(self, workflow_params=None, workflow_manage=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.workflow_params = workflow_params
self.workflow_manage = workflow_manage
@staticmethod
def upload_base_file(source_type, source_id, file_obj):
meta = {
'debug': False,
# additional metadata that may vary depending on the type
}
# Assuming a method exists in your codebase to construct file details correctly
file_details = FileSerializer.serialize(
obj=file_obj,
meta=meta,
source_type=source_type,
source_id=source_id
)
return FileSerializer.create(file_details).url
def upload_file(self, file):
self.log.info("Uploading image")
mode = self.workflow_manage.flow.workflow_mode
if mode == WorkflowMode.KNOWLEDGE or mode == WorkflowMode.KNOWLEDGE_LOOP:
return self.upload_knowledge_file(file)
elif mode == WorkflowMode.TOOL or mode == WorkflowMode.TOOL_LOOP:
return self.upload_tool_file(file)
else:
raise ValueError(f"Unsupported workflow mode: {mode}")
def upload_knowledge_file(self, file):
file_url = self.upload_base_file(FileSourceType.KNOWLEDGE.value, None, file)
return file_url
def upload_tool_file(self, file):
try:
tool_id = self.workflow_params['tool_id']
file_url = self.upload_base_file(FileSourceType.TOOL.value, tool_id, file)
return file_url
except KeyError:
raise LookupError("Tool ID not found in workflow parameters")Summary Changes:
Feel free to adjust the implementation of private methods like |
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| @date:2026/3/12 15:17 | ||
| @desc: | ||
| """ | ||
| import time | ||
| from concurrent.futures import ThreadPoolExecutor | ||
|
|
||
| from django.db import close_old_connections | ||
|
|
@@ -32,6 +33,14 @@ def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostH | |
| def get_params_serializer_class(self): | ||
| return ToolFlowParamsSerializer | ||
|
|
||
| def run(self): | ||
| self.context['start_time'] = time.time() | ||
| close_old_connections() | ||
| language = get_language() | ||
| if self.params.get('stream'): | ||
| return self.run_stream(self.start_node, None, language) | ||
| return self.run_block(language) | ||
|
|
||
| def stream(self): | ||
| close_old_connections() | ||
| language = get_language() | ||
|
|
@@ -48,6 +57,30 @@ def get_base_node(self): | |
| """ | ||
| return self.flow.get_node('tool-base-node') | ||
|
|
||
| def get_input_field_list(self): | ||
| """ | ||
| 获取输入字段列表 | ||
| @return: 输入字段配置 | ||
| """ | ||
| base_node = self.get_base_node() | ||
| return base_node.properties.get("user_input_field_list") or [] | ||
|
|
||
| def get_output_field_list(self): | ||
| """ | ||
| 获取输出字段列表配置 | ||
| @return: 输出字段列表配置 | ||
| """ | ||
| base_node = self.get_base_node() | ||
| return base_node.properties.get("user_output_field_list") or [] | ||
|
|
||
| def get_input(self): | ||
| """ | ||
| 获取用户输入 | ||
| @return: 用户输入 | ||
| """ | ||
| input_field_list = self.get_input_field_list() | ||
| return {f.get('field'): self.params.get(f.get('field')) for f in input_field_list} | ||
|
|
||
| def get_source_type(self): | ||
| return "TOOL" | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Code Review Summary:Functionality and Structure:
Potential Issues:
Suggestions for Optimization and Improvements:
By addressing these aspects, you can make the code more robust, efficient, and easier to maintain. Here’s an example of how the improved version might look: from datetime import time
from concurrent.futures import ThreadPoolExecutor
from django.db import close_old_connections
from .utils import get_language # Assuming utils module exists and contains get_language function
class ToolProcess(WorkflowHandler, DataTransferBase, WorkFlowPostHandler):
def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler):
super().__init__()
self.flow = flow
self.params = params
self.work_flow_post_handler = work_flow_post_handler
def get_params_serializer_class(self):
return ToolFlowParamsSerializer
def run(self):
self.context['start_time'] = time.time()
close_old_connections()
language = get_language()
if self.params.get('stream'):
return self.run_stream(self.start_node, None, language)
return self.run_block(language)
def stream(self):
close_old_connections()
language = get_language()
# ... rest of the streaming logic ...
def get_base_node(self):
return self.flow.get_node('tool-base-node')
def get_input_field_list(self):
base_node = self.get_base_node()
return base_node.properties.get("user_input_field_list", [])
def get_output_field_list(self):
base_node = self.get_base_node()
return base_node.properties.get("user_output_field_list", [])
def get_input(self):
'''
获取用户输入并按配置筛选所需字段
:return:
'''
input_field_list = self.get_input_field_list()
filters = {}
for field_config in input_field_list:
field_name = field_config.get('field')
required = field_config.get('required', False) # Default value is False
if field_name and required:
filters[field_name] = self.params.get(field_name)
return dict(filters) if filters else {}
# Additional methods...This refactored version includes complete implementation of the |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no syntax error in your code snippet provided. However, there are areas where optimization can be made:
Repeated Function Calls:
The function
get_toolscalls itself without memoization to retrieve tools based on IDs and workspace ID. This could potentially increase the computation time for multiple calls, especially if the same data needs to be fetched repeatedly.Hardcoded Directives:
You have hardcoded strings like "workflow" in comments, which should ideally come from a localization module. These constants should be translated into the appropriate language to maintain consistency with user-facing messages.
Redundant Imports:
While not causing errors specifically, it may be useful to remove unused imports such as those related to translation utilities (
gettext) since they're unused elsewhere.Unused Variables:
There are variables used but never defined (e.g.,
_,qv, etc.). If these are intended to be temporary or placeholders, consider eliminating them.Here's an optimized version of the code with comments explaining changes:
Remember to manage caching efficiently depending on how often and frequently similar queries are performed.