11 changes: 10 additions & 1 deletion apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -82,6 +82,12 @@ class InstanceSerializer(serializers.Serializer):

model_params_setting = serializers.DictField(required=False, allow_null=True,
label=_("Model parameter settings"))
mcp_enable = serializers.BooleanField(label="Whether MCP is enabled", required=False, default=False)
mcp_tool_ids = serializers.JSONField(label="MCP tool ID list", required=False, default=list)
mcp_servers = serializers.JSONField(label="MCP server list", required=False, default=dict)
mcp_source = serializers.CharField(label="MCP Source", required=False, default="referencing")
tool_enable = serializers.BooleanField(label="Whether tools are enabled", required=False, default=False)
tool_ids = serializers.JSONField(label="Tool ID list", required=False, default=list)

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -106,5 +112,8 @@ def execute(self, message_list: List[BaseMessage],
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
no_references_setting=None, model_params_setting=None, model_setting=None,
mcp_enable=False, mcp_tool_ids=None, mcp_servers='', mcp_source="referencing",
tool_enable=False, tool_ids=None,
**kwargs):
pass
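
For reference, a payload exercising the new fields might look like the sketch below. All values are illustrative, not taken from this PR, and the per-server config shape follows MultiServerMCPClient conventions, so treat it as an assumption:

import json

# Hypothetical example of the new fields as they would arrive in a request
# payload; every value below is illustrative only.
step_config = {
    'mcp_enable': True,
    'mcp_source': 'custom',        # 'custom' inline servers vs. 'referencing' saved MCP tools
    'mcp_servers': json.dumps({    # JSON string, parsed later with json.loads()
        'weather': {'transport': 'sse', 'url': 'http://localhost:8000/sse'}
    }),
    'mcp_tool_ids': [],            # used when mcp_source == 'referencing'
    'tool_enable': True,
    'tool_ids': ['<tool-uuid>'],   # regular tools converted to MCP servers at runtime
}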
130 changes: 116 additions & 14 deletions apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -6,7 +6,8 @@
@date:2024/1/9 18:25
@desc: Base implementation of the chat step
"""
import logging
import json
import os
import time
import traceback
import uuid_utils.compat as uuid
@@ -24,10 +25,14 @@
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.flow.tools import Reasoning, mcp_response_generator
from application.models import ApplicationChatUserStats, ChatUserType
from common.utils.logger import maxkb_logger
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.tool_code import ToolExecutor
from maxkb.const import CONFIG
from models_provider.tools import get_model_instance_by_model_workspace_id
from tools.models import Tool


def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@@ -54,6 +59,7 @@ def write_context(step, manage, request_token, response_token, all_text):
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token



def event_content(response,
chat_id,
chat_record_id,
@@ -169,6 +175,12 @@ def execute(self, message_list: List[BaseMessage],
no_references_setting=None,
model_params_setting=None,
model_setting=None,
mcp_enable=False,
mcp_tool_ids=None,
mcp_servers='',
mcp_source="referencing",
tool_enable=False,
tool_ids=None,
**kwargs):
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting) if model_id is not None else None
@@ -177,14 +189,24 @@
paragraph_list,
manage, padding_problem_text, chat_user_id, chat_user_type,
no_references_setting,
model_setting)
model_setting,
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
model_setting)
model_setting,
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)

def get_details(self, manage, **kwargs):
# Delete the temporarily generated MCP code files
if self.context.get('execute_ids'):
executor = ToolExecutor(CONFIG.get('SANDBOX'))
# Clean up the tool code files with deferred deletion, so they are not removed while still in use
for tool_id in self.context.get('execute_ids'):
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
if os.path.exists(code_path):
os.remove(code_path)
return {
'step_type': 'chat_step',
'run_time': self.context['run_time'],
@@ -206,12 +228,63 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
result.append({'role': 'ai', 'content': answer_text})
return result

@staticmethod
def get_stream_result(message_list: List[BaseMessage],
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
chat_model, message_list):
if not mcp_enable and not tool_enable:
return None

mcp_servers_config = {}

# mcp_source is None for data migrated from older versions
if mcp_source is None:
mcp_source = 'custom'
if mcp_enable:
# Backward compatibility with legacy data
if not mcp_tool_ids:
mcp_tool_ids = []
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
mcp_servers_config = json.loads(mcp_servers)
elif mcp_tool_ids:
mcp_tools = QuerySet(Tool).filter(id__in=mcp_tool_ids).values()
for mcp_tool in mcp_tools:
if mcp_tool and mcp_tool['is_active']:
mcp_servers_config = {**mcp_servers_config, **json.loads(mcp_tool['code'])}

if tool_enable:
if tool_ids and len(tool_ids) > 0:  # If tool IDs are present, convert them to MCP servers
self.context['tool_ids'] = tool_ids
self.context['execute_ids'] = []
for tool_id in tool_ids:
tool = QuerySet(Tool).filter(id=tool_id).first()
if not tool.is_active:
continue
executor = ToolExecutor(CONFIG.get('SANDBOX'))
if tool.init_params is not None:
params = json.loads(rsa_long_decrypt(tool.init_params))
else:
params = {}
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)

self.context['execute_ids'].append(_id)
mcp_servers_config[str(tool.id)] = tool_config

if len(mcp_servers_config) > 0:
return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))

return None
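
To make the hand-off concrete, here is a hedged sketch of what mcp_servers_config can contain before serialization: one entry parsed from a custom mcp_servers string and one generated for a regular tool. The stdio command/args shape for generated tools is an assumption; only the str(tool.id) key and the sandbox script path are implied by the code above.

# Assumed shape only; real entries come from json.loads(mcp_servers) and
# executor.get_tool_mcp_config(tool.code, params).
mcp_servers_config = {
    # From a custom mcp_servers JSON string (mcp_source == 'custom'):
    'weather': {'transport': 'sse', 'url': 'http://localhost:8000/sse'},
    # From a converted regular tool, keyed by str(tool.id); the generated
    # script lives under <sandbox_path>/execute/<execute_id>.py (hypothetical values):
    '<tool-uuid>': {'transport': 'stdio', 'command': 'python',
                    'args': ['<sandbox_path>/execute/<execute_id>.py']},
}
result = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))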


def get_stream_result(self, message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
problem_text=None):
problem_text=None,
mcp_enable=False,
mcp_tool_ids=None,
mcp_servers='',
mcp_source="referencing",
tool_enable=False,
tool_ids=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -227,6 +300,12 @@ def get_stream_result(message_list: List[BaseMessage],
return iter([AIMessageChunk(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.'))]), False
else:
# Handle the MCP request
mcp_result = self._handle_mcp_request(
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
)
if mcp_result:
return mcp_result, True
return chat_model.stream(message_list), True

def execute_stream(self, message_list: List[BaseMessage],
@@ -239,9 +318,15 @@
padding_problem_text: str = None,
chat_user_id=None, chat_user_type=None,
no_references_setting=None,
model_setting=None):
model_setting=None,
mcp_enable=False,
mcp_tool_ids=None,
mcp_servers='',
mcp_source="referencing",
tool_enable=False,
tool_ids=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
chat_record_id = uuid.uuid7()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -253,12 +338,17 @@
r['Cache-Control'] = 'no-cache'
return r

@staticmethod
def get_block_result(message_list: List[BaseMessage],
def get_block_result(self, message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
no_references_setting=None,
problem_text=None):
problem_text=None,
mcp_enable=False,
mcp_tool_ids=None,
mcp_servers='',
mcp_source="referencing",
tool_enable=False,
tool_ids=None):
if paragraph_list is None:
paragraph_list = []
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -273,6 +363,12 @@ def get_block_result(message_list: List[BaseMessage],
return AIMessage(
_('Sorry, the AI model is not configured. Please go to the application to set up the AI model first.')), False
else:
# Handle the MCP request
mcp_result = self._handle_mcp_request(
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, chat_model, message_list,
)
if mcp_result:
return mcp_result, True
return chat_model.invoke(message_list), True

def execute_block(self, message_list: List[BaseMessage],
@@ -284,7 +380,13 @@
manage: PipelineManage = None,
padding_problem_text: str = None,
chat_user_id=None, chat_user_type=None, no_references_setting=None,
model_setting=None):
model_setting=None,
mcp_enable=False,
mcp_tool_ids=None,
mcp_servers='',
mcp_source="referencing",
tool_enable=False,
tool_ids=None):
reasoning_content_enable = model_setting.get('reasoning_content_enable', False)
reasoning_content_start = model_setting.get('reasoning_content_start', '<think>')
reasoning_content_end = model_setting.get('reasoning_content_end', '</think>')
@@ -294,7 +396,7 @@
# Invoke the model
try:
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text)
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids)
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
@@ -6,59 +6,28 @@
@date:2024/6/4 14:30
@desc:
"""
import asyncio
import json
import os
import re
import sys
import time
import traceback
from functools import reduce
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import BaseMessage, AIMessage


from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning
from common.utils.logger import maxkb_logger
from application.flow.tools import Reasoning, mcp_response_generator
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.tool_code import ToolExecutor
from maxkb.const import CONFIG
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool

tool_message_template = """
<details>
<summary>
<strong>Called MCP Tool: <em>%s</em></strong>
</summary>

%s

</details>

"""

tool_message_json_template = """
```json
%s
```
"""


def generate_tool_message_template(name, context):
if '```' in context:
return tool_message_template % (name, context)
else:
return tool_message_template % (name, tool_message_json_template % (context))


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
@@ -122,39 +91,6 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


async def _yield_mcp_response(chat_model, message_list, mcp_servers):
client = MultiServerMCPClient(json.loads(mcp_servers))
tools = await client.get_tools()
agent = create_react_agent(chat_model, tools)
response = agent.astream({"messages": message_list}, stream_mode='messages')
async for chunk in response:
if isinstance(chunk[0], ToolMessage):
content = generate_tool_message_template(chunk[0].name, chunk[0].content)
chunk[0].content = content
yield chunk[0]
if isinstance(chunk[0], AIMessageChunk):
yield chunk[0]


def mcp_response_generator(chat_model, message_list, mcp_servers):
loop = asyncio.new_event_loop()
try:
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers)
while True:
try:
chunk = loop.run_until_complete(anext_async(async_gen))
yield chunk
except StopAsyncIteration:
break
except Exception as e:
maxkb_logger.error(f'Exception: {e}', traceback.format_exc())
finally:
loop.close()


async def anext_async(agen):
return await agen.__anext__()


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
"""

This code snippet has several issues that need to be addressed:

  1. Unused Import: uuid_utils.compat as uuid is imported but not used anywhere in the code.
  2. Unused Imports: sys, time, and traceback can also be removed without affecting functionality.
  3. Redundant Code: _yield_mcp_response creates an asynchronous generator, but mcp_response_generator then wraps it into a synchronous generator with its own event-loop handling, which duplicates effort.
  4. Tool Message Generation: the generate_tool_message_template helper is redundant; the tool message can be formatted directly inside _yield_mcp_response.
  5. Synchronous Loop Handling: mcp_response_generator drives the generator with nested while and try-except loops; these can be combined or refactored for clarity and efficiency.

Here’s an optimized version of the code with improvements:

import asyncio
import json
import os
import re
from functools import reduce
from typing import List, Dict

from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage, AIMessage, AIMessageChunk, ToolMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.prebuilt import create_react_agent

from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning
from common.utils.logger import maxkb_logger
from common.utils.rsa_util import rsa_long_decrypt
from common.utils.tool_code import ToolExecutor
from maxkb.const import CONFIG
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
from tools.models import Tool

tool_message_template = "<details>\n<summary><strong>Called MCP Tool: <em>{name}</em></strong></summary>\n\n{context}\n\n</details>"
tool_message_json_template = "```json\n{context}\n```"


async def _yield_mcp_response(chat_model, message_list, mcp_servers):
    client = MultiServerMCPClient(json.loads(mcp_servers))
    tools = await client.get_tools()
    agent = create_react_agent(chat_model, tools)
    response = agent.astream({"messages": message_list}, stream_mode='messages')

    async for chunk in response:
        message = chunk[0]
        if isinstance(message, ToolMessage):
            # Format the tool output inline instead of calling a separate helper
            context = message.content
            if '```' not in context:
                context = tool_message_json_template.format(context=context)
            message.content = tool_message_template.format(name=message.name, context=context)
            yield message
        elif isinstance(message, AIMessageChunk):
            yield message


def mcp_response_generator(chat_model, message_list, mcp_servers):
    # Return the async generator directly; the caller drives it with an event loop
    return _yield_mcp_response(chat_model, message_list, mcp_servers)


async def anext_async(agen):
    return await agen.__anext__()


def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
                  reasoning_content: str):
    """Write context logic here (unchanged)."""
    pass


def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
                         reasoning_content: str):
    """Consume the response stream, then delegate to the shared context writer."""
    _write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)


def update_context(acc, key, value):
    acc[key] = value
    return acc

# Example usage in the main execution logic:
# write_context_stream(node_variable, workflow_variable, node, workflow, answer, reasoning_content)
Key Changes:

  • Removed Unnecessary Imports: dropped the unused sys, time, traceback, and uuid imports.
  • Inlined Tool Message Formatting: generate_tool_message_template is gone; the <details> wrapper is applied directly in _yield_mcp_response.
  • Simplified Generator Handling: mcp_response_generator now returns the async generator directly instead of bridging it through a private event loop with nested while/try-except loops.
  • Updated Comments: improved comments for clarity.
  • Example Usage: added a usage note showing how write_context_stream might be called within the larger pipeline.

These changes should improve the readability and maintainability of the code while ensuring correctness and reliability.
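
Note that if mcp_response_generator returns the async generator directly, synchronous callers such as the chat step still need a bridge. A minimal sketch of that bridge, mirroring the run_until_complete loop the original code used (names as in the snippet above; assumes the caller is not already inside a running event loop):

def iter_mcp_chunks(chat_model, message_list, mcp_servers):
    # Synchronous wrapper around the async generator
    loop = asyncio.new_event_loop()
    try:
        agen = mcp_response_generator(chat_model, message_list, mcp_servers)
        while True:
            try:
                yield loop.run_until_complete(anext_async(agen))
            except StopAsyncIteration:
                break
    finally:
        loop.close()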
