Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning, mcp_response_generator
from application.flow.tools import Reasoning, mcp_response_generator, get_tools
from application.models import ApplicationChatUserStats, ChatUserType, Application, ApplicationApiKey, \
ApplicationAccessToken
from common.exception.app_exception import AppApiException
Expand All @@ -31,7 +31,7 @@
from common.utils.shared_resource_auth import filter_authorized_ids
from common.utils.tool_code import ToolExecutor
from models_provider.tools import get_model_instance_by_model_workspace_id
from tools.models import Tool
from tools.models import Tool, ToolType


def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
Expand Down Expand Up @@ -232,7 +232,7 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):

def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
application_ids, skill_tool_ids, mcp_output_enable, chat_model, message_list, agent_id,
chat_id):
chat_id, workspace_id):

mcp_servers_config = {}

Expand All @@ -252,10 +252,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}

tool_init_params = {}
tools = get_tools("APPLICATION", agent_id, tool_ids,
workspace_id)
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
self.context['tool_ids'] = tool_ids
for tool_id in tool_ids:
tool = QuerySet(Tool).filter(id=tool_id).first()
tool = QuerySet(Tool).filter(id=tool_id, tool_type=ToolType.CUSTOM).first()
if tool is None or tool.is_active is False:
continue
executor = ToolExecutor()
Expand Down Expand Up @@ -316,12 +318,12 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
})
mcp_servers_config['skills'] = skill_file_items

if len(mcp_servers_config) > 0:
if len(mcp_servers_config) > 0 or len(tools) > 0:
source_id = agent_id
source_type = 'APPLICATION'
return mcp_response_generator(
chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
tool_init_params, source_id, source_type, chat_id
tool_init_params, source_id, source_type, chat_id, tools
)

return None
Expand Down Expand Up @@ -372,7 +374,7 @@ def get_stream_result(self, message_list: List[BaseMessage],
mcp_result = self._handle_mcp_request(
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
application_ids, skill_tool_ids, mcp_output_enable, chat_model,
message_list, agent_id, chat_id
message_list, agent_id, chat_id, workspace_id
)
if mcp_result:
return mcp_result, True
Expand Down Expand Up @@ -461,7 +463,7 @@ def get_block_result(self, message_list: List[BaseMessage],
mcp_result = self._handle_mcp_request(
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
application_ids, skill_tool_ids, mcp_output_enable,
chat_model, message_list, application_id, chat_id
chat_model, message_list, application_id, chat_id, workspace_id
)
if mcp_result:
return mcp_result, True
Expand Down Expand Up @@ -496,7 +498,7 @@ def execute_block(self, message_list: List[BaseMessage],
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text,
mcp_tool_ids, mcp_servers, mcp_source,
tool_ids, application_ids, skill_tool_ids,workspace_id,
tool_ids, application_ids, skill_tool_ids, workspace_id,
mcp_output_enable, manage.context.get('application_id'),
chat_id)
if is_ai_chat:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,123 +9,25 @@
import json
import re
import time
import uuid
from functools import reduce
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet, OuterRef, Subquery
from django.db.models import QuerySet
from django.utils.translation import gettext as _
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool
from pydantic import Field, create_model

from application.flow.common import Workflow, WorkflowMode
from application.flow.i_step_node import NodeResult, INode, ToolWorkflowPostHandler
from application.flow.common import WorkflowMode
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning, mcp_response_generator
from application.flow.tools import Reasoning, mcp_response_generator, get_tools
from application.models import Application, ApplicationApiKey, ApplicationAccessToken
from application.serializers.common import ToolExecute
from common.exception.app_exception import AppApiException
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.shared_resource_auth import filter_authorized_ids
from common.utils.tool_code import ToolExecutor
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool, ToolWorkflowVersion, ToolType


def build_schema(fields: dict):
return create_model("dynamicSchema", **fields)


def get_type(_type: str):
if _type == 'float':
return float
if _type == 'string':
return str
if _type == 'int':
return int
if _type == 'dict':
return dict
if _type == 'array':
return list
if _type == 'boolean':
return bool
return object


def get_workflow_args(tool, qv):
for node in qv.work_flow.get('nodes'):
if node.get('type') == 'tool-base-node':
input_field_list = node.get('properties').get('user_input_field_list')
return build_schema(
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
for field in input_field_list})

return build_schema({})


def get_workflow_func(node, tool, qv, workspace_id):
tool_id = tool.id
tool_record_id = str(uuid.uuid7())
took_execute = ToolExecute(tool_id, tool_record_id,
workspace_id,
node.workflow_manage.get_source_type(),
node.workflow_manage.get_source_id(),
False)

def inner(**kwargs):
from application.flow.tool_workflow_manage import ToolWorkflowManage
work_flow_manage = ToolWorkflowManage(
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
{
'chat_record_id': tool_record_id,
'tool_id': tool_id,
'stream': True,
'workspace_id': workspace_id,
**kwargs},

ToolWorkflowPostHandler(took_execute, tool_id),
is_the_task_interrupted=lambda: False,
child_node=None,
start_node_id=None,
start_node_data=None,
chat_record=None
)
res = work_flow_manage.run()
for r in res:
pass
return work_flow_manage.out_context

return inner


def get_tools(node, tool_workflow_ids, workspace_id):
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
latest_subquery = ToolWorkflowVersion.objects.filter(
tool_id=OuterRef('tool_id')
).order_by('-create_time')

qs = ToolWorkflowVersion.objects.filter(
tool_id__in=[t.id for t in tools],
id=Subquery(latest_subquery.values('id')[:1])
)
qd = {q.tool_id: q for q in qs}
results = []
for tool in tools:
qv = qd.get(tool.id)
func = get_workflow_func(node, tool, qv, workspace_id)
args = get_workflow_args(tool, qv)
tool = StructuredTool.from_function(
func=func,
name=tool.name,
description=tool.desc,
args_schema=args,
)
results.append(tool)

return results
from tools.models import Tool, ToolType


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
Expand Down Expand Up @@ -362,7 +264,8 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}
mcp_servers_config = self.handle_variables(mcp_servers_config)
tool_init_params = {}
tools = get_tools(self, tool_ids, workspace_id)
tools = get_tools(self.workflow_manage.get_source_type(), self.workflow_manage.get_source_id(), tool_ids,
workspace_id)
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
self.context['tool_ids'] = tool_ids
for tool_id in tool_ids:
Expand Down
106 changes: 102 additions & 4 deletions apps/application/flow/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
@date:2024/6/6 15:15
@desc:
"""
from tools.models import ToolRecord, Tool, ToolScope
from langchain_core.tools import StructuredTool

from application.flow.common import Workflow, WorkflowMode
from application.serializers.common import ToolExecute
from tools.models import ToolRecord, Tool, ToolScope, ToolWorkflowVersion, ToolType
from maxkb.const import CONFIG
from knowledge.models.knowledge_action import State
from knowledge.models import File
from common.utils.logger import maxkb_logger
from common.result import result
from application.flow.i_step_node import WorkFlowPostHandler
from application.flow.i_step_node import WorkFlowPostHandler, ToolWorkflowPostHandler
from application.flow.backend.sandbox_shell import SandboxShellBackend
import asyncio
import io
Expand All @@ -25,11 +29,11 @@
import zipfile
from functools import reduce
from typing import Iterator

from pydantic import Field, create_model
import uuid_utils.compat as uuid
from asgiref.sync import sync_to_async
from deepagents import create_deep_agent
from django.db.models import QuerySet
from django.db.models import QuerySet, OuterRef, Subquery
from django.http import StreamingHttpResponse
from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk, SystemMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
Expand Down Expand Up @@ -990,3 +994,97 @@ def get_child_tool_id_list(work_flow, response):
for tool in tool_list:
response.append(str(tool.id))
return response


def build_schema(fields: dict):
return create_model("dynamicSchema", **fields)


def get_type(_type: str):
if _type == 'float':
return float
if _type == 'string':
return str
if _type == 'int':
return int
if _type == 'dict':
return dict
if _type == 'array':
return list
if _type == 'boolean':
return bool
return object


def get_workflow_args(tool, qv):
for node in qv.work_flow.get('nodes'):
if node.get('type') == 'tool-base-node':
input_field_list = node.get('properties').get('user_input_field_list')
return build_schema(
{field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
for field in input_field_list})

return build_schema({})


def get_workflow_func(source_type, source_id, tool, qv, workspace_id):
tool_id = tool.id
tool_record_id = str(uuid.uuid7())
took_execute = ToolExecute(tool_id, tool_record_id,
workspace_id,
source_type,
source_id,
False)

def inner(**kwargs):
from application.flow.tool_workflow_manage import ToolWorkflowManage
work_flow_manage = ToolWorkflowManage(
Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
{
'chat_record_id': tool_record_id,
'tool_id': tool_id,
'stream': True,
'workspace_id': workspace_id,
**kwargs},

ToolWorkflowPostHandler(took_execute, tool_id),
is_the_task_interrupted=lambda: False,
child_node=None,
start_node_id=None,
start_node_data=None,
chat_record=None
)
res = work_flow_manage.run()
for r in res:
pass
return work_flow_manage.out_context

return inner


def get_tools(source_type, source_id, tool_workflow_ids, workspace_id):
tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, tool_type=ToolType.WORKFLOW, workspace_id=workspace_id)
latest_subquery = ToolWorkflowVersion.objects.filter(
tool_id=OuterRef('tool_id')
).order_by('-create_time')

qs = ToolWorkflowVersion.objects.filter(
tool_id__in=[t.id for t in tools],
id=Subquery(latest_subquery.values('id')[:1])
)
qd = {q.tool_id: q for q in qs}
results = []
for tool in tools:
qv = qd.get(tool.id)
func = get_workflow_func(source_type, source_id, tool, qv,
workspace_id)
args = get_workflow_args(tool, qv)
tool = StructuredTool.from_function(
func=func,
name=tool.name,
description=tool.desc,
args_schema=args,
)
results.append(tool)

return results
Loading