Skip to content

Commit f455f2d

Browse files
committed
fix: update MCP request handling to include source ID and type
1 parent 203778d commit f455f2d

File tree

4 files changed

+112
-24
lines changed

4 files changed

+112
-24
lines changed

apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def reset_message_list(message_list: List[BaseMessage], answer_text):
227227
return result
228228

229229
def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
230-
application_ids, mcp_output_enable, chat_model, message_list):
230+
application_ids, mcp_output_enable, chat_model, message_list, agent_id):
231231

232232
mcp_servers_config = {}
233233

@@ -258,7 +258,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
258258
tool_init_params = json.loads(rsa_long_decrypt(tool.init_params))
259259
else:
260260
params = {}
261-
tool_config = executor.get_tool_mcp_config(tool.code, params, tool.name, tool.desc)
261+
tool_config = executor.get_tool_mcp_config(tool, params)
262262

263263
mcp_servers_config[str(tool.id)] = tool_config
264264

@@ -290,7 +290,10 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
290290
mcp_servers_config[app.name] = app_config
291291

292292
if len(mcp_servers_config) > 0:
293-
return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable, tool_init_params)
293+
source_id = agent_id
294+
source_type = 'APPLICATION'
295+
return mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
296+
tool_init_params, source_id, source_type)
294297

295298
return None
296299

@@ -304,7 +307,9 @@ def get_stream_result(self, message_list: List[BaseMessage],
304307
mcp_source="referencing",
305308
tool_ids=None,
306309
application_ids=None,
307-
mcp_output_enable=True):
310+
mcp_output_enable=True,
311+
agent_id=None
312+
):
308313
if paragraph_list is None:
309314
paragraph_list = []
310315
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
@@ -324,7 +329,7 @@ def get_stream_result(self, message_list: List[BaseMessage],
324329
mcp_result = self._handle_mcp_request(
325330
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
326331
application_ids, mcp_output_enable, chat_model,
327-
message_list,
332+
message_list, agent_id
328333
)
329334
if mcp_result:
330335
return mcp_result, True
@@ -351,7 +356,7 @@ def execute_stream(self, message_list: List[BaseMessage],
351356
no_references_setting, problem_text, mcp_tool_ids,
352357
mcp_servers, mcp_source, tool_ids,
353358
application_ids,
354-
mcp_output_enable)
359+
mcp_output_enable, manage.context.get('application_id'))
355360
chat_record_id = self.context.get('step_args', {}).get('chat_record_id') if self.context.get('step_args',
356361
{}).get(
357362
'chat_record_id') else uuid.uuid7()
@@ -375,7 +380,8 @@ def get_block_result(self, message_list: List[BaseMessage],
375380
mcp_source="referencing",
376381
tool_ids=None,
377382
application_ids=None,
378-
mcp_output_enable=True
383+
mcp_output_enable=True,
384+
application_id=None
379385
):
380386
if paragraph_list is None:
381387
paragraph_list = []
@@ -395,7 +401,7 @@ def get_block_result(self, message_list: List[BaseMessage],
395401
mcp_result = self._handle_mcp_request(
396402
mcp_source, mcp_servers, mcp_tool_ids, tool_ids,
397403
application_ids, mcp_output_enable,
398-
chat_model, message_list,
404+
chat_model, message_list, application_id
399405
)
400406
if mcp_result:
401407
return mcp_result, True
@@ -429,7 +435,7 @@ def execute_block(self, message_list: List[BaseMessage],
429435
no_references_setting, problem_text,
430436
mcp_tool_ids, mcp_servers, mcp_source,
431437
tool_ids, application_ids,
432-
mcp_output_enable)
438+
mcp_output_enable, manage.context.get('application_id'))
433439
if is_ai_chat:
434440
request_token = chat_model.get_num_tokens_from_messages(message_list)
435441
response_token = chat_model.get_num_tokens(chat_result.content)

apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
243243
tool_init_params = json.loads(rsa_long_decrypt(tool.init_params))
244244
else:
245245
params = {}
246-
tool_config = executor.get_tool_mcp_config(tool.code, params, tool.name, tool.desc)
246+
tool_config = executor.get_tool_mcp_config(tool, params)
247247

248248
mcp_servers_config[str(tool.id)] = tool_config
249249

@@ -275,7 +275,17 @@ def _handle_mcp_request(self, mcp_source, mcp_servers, mcp_tool_id, mcp_tool_ids
275275
mcp_servers_config[app.name] = app_config
276276

277277
if len(mcp_servers_config) > 0:
278-
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable, tool_init_params)
278+
# 安全获取 application
279+
application_id = None
280+
if (self.workflow_manage and
281+
self.workflow_manage.work_flow_post_handler and
282+
self.workflow_manage.work_flow_post_handler.chat_info):
283+
application_id = self.workflow_manage.work_flow_post_handler.chat_info.application.id
284+
knowledge_id = self.workflow_params.get('knowledge_id')
285+
source_id = application_id or knowledge_id
286+
source_type = 'APPLICATION' if application_id else 'KNOWLEDGE'
287+
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config), mcp_output_enable,
288+
tool_init_params, source_id, source_type)
279289
return NodeResult(
280290
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
281291
'history_message': [{'content': message.content, 'role': message.type} for message in

apps/application/flow/tools.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from functools import reduce
1515
from typing import Iterator
1616

17+
import uuid_utils.compat as uuid
18+
from asgiref.sync import sync_to_async
19+
from django.db.models import QuerySet
1720
from django.http import StreamingHttpResponse
1821
from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk
1922
from langchain_mcp_adapters.client import MultiServerMCPClient
@@ -22,7 +25,9 @@
2225
from application.flow.i_step_node import WorkFlowPostHandler
2326
from common.result import result
2427
from common.utils.logger import maxkb_logger
28+
from knowledge.models.knowledge_action import State
2529
from maxkb.const import CONFIG
30+
from tools.models import ToolRecord, Tool
2631

2732

2833
class Reasoning:
@@ -316,7 +321,8 @@ def _extract_tool_id(raw_id):
316321
return tool_id or raw_id
317322

318323

319-
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}):
324+
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
325+
source_id=None, source_type=None):
320326
try:
321327
client = MultiServerMCPClient(json.loads(mcp_servers))
322328
tools = await client.get_tools()
@@ -393,10 +399,14 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_
393399

394400
if tool_id in tool_calls_info:
395401
tool_info = tool_calls_info[tool_id]
402+
tool_result = json.loads(chunk[0].content)
403+
tool_id = tool_result.pop('tool_id')
404+
if tool_id:
405+
await save_tool_record(tool_id, tool_info, tool_result, source_id, source_type)
396406
content = generate_tool_message_complete(
397407
tool_info['name'],
398408
tool_info['input'],
399-
chunk[0].content
409+
json.dumps(tool_result),
400410
)
401411
chunk[0].content = content
402412
else:
@@ -421,14 +431,30 @@ def get_real_error(exc):
421431
raise RuntimeError(error_msg) from None
422432

423433

424-
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={}):
434+
async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_type):
435+
tool = await sync_to_async(lambda: QuerySet(Tool).filter(id=tool_id).first())()
436+
tool_record = ToolRecord(
437+
id=uuid.uuid7(),
438+
workspace_id=tool.workspace_id,
439+
tool_id=tool_id,
440+
source_type=source_type,
441+
source_id=source_id,
442+
meta={'input': tool_info['input'], 'output': tool_result},
443+
state=State.SUCCESS
444+
)
445+
await sync_to_async(tool_record.save)()
446+
447+
448+
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
449+
source_id=None, source_type=None):
425450
"""使用全局事件循环,不创建新实例"""
426451
result_queue = queue.Queue()
427452
loop = get_global_loop() # 使用共享循环
428453

429454
async def _run():
430455
try:
431-
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params)
456+
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params,
457+
source_id, source_type)
432458
async for chunk in async_gen:
433459
result_queue.put(('data', chunk))
434460
except Exception as e:

apps/common/utils/tool_code.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def exec_code(self, code_str, keywords, function_name=None):
137137
return result.get('data')
138138
raise Exception(result.get('msg') + (f'\n{subprocess_result.stderr}' if subprocess_result.stderr else ''))
139139

140-
def _generate_mcp_server_code(self, _code, params, name=None, description=None):
141-
# 解析代码提取导入语句和函数定义
140+
def _generate_mcp_server_code(self, _code, params, name=None, description=None, tool_id=None):
141+
# 解析代码,提取导入语句和函数定义
142142
try:
143143
tree = ast.parse(_code)
144144
except SyntaxError:
@@ -155,7 +155,7 @@ def _generate_mcp_server_code(self, _code, params, name=None, description=None):
155155
continue
156156
# 修改函数参数以包含 params 中的默认值
157157
arg_names = [arg.arg for arg in node.args.args]
158-
# 为参数添加默认值确保参数顺序正确
158+
# 为参数添加默认值,确保参数顺序正确
159159
defaults = []
160160
num_defaults = 0
161161
# 从后往前检查哪些参数有默认值
@@ -178,26 +178,72 @@ def _generate_mcp_server_code(self, _code, params, name=None, description=None):
178178
else:
179179
defaults.append(ast.Constant(value=str(default_value)))
180180
else:
181-
# 如果某个参数没有默认值需要添加 None 占位
181+
# 如果某个参数没有默认值,需要添加 None 占位
182182
defaults.append(ast.Constant(value=None))
183183
node.args.defaults = defaults
184+
185+
# 修改返回类型注解为 Result
186+
node.returns = ast.Name(id='Result', ctx=ast.Load())
187+
188+
# 修改 return 语句为 return Result(result=..., tool_id=...)
189+
class ReturnTransformer(ast.NodeTransformer):
190+
def __init__(self, func_name):
191+
self.func_name = func_name
192+
193+
def visit_Return(self, node):
194+
if node.value is None:
195+
# return 语句没有返回值
196+
new_return = ast.Return(
197+
value=ast.Call(
198+
func=ast.Name(id='Result', ctx=ast.Load()),
199+
args=[],
200+
keywords=[
201+
ast.keyword(arg='result', value=ast.Constant(value=None)),
202+
ast.keyword(arg='tool_id', value=ast.Constant(value=tool_id))
203+
]
204+
)
205+
)
206+
else:
207+
# return 语句有返回值
208+
new_return = ast.Return(
209+
value=ast.Call(
210+
func=ast.Name(id='Result', ctx=ast.Load()),
211+
args=[],
212+
keywords=[
213+
ast.keyword(arg='result', value=node.value),
214+
ast.keyword(arg='tool_id', value=ast.Constant(value=tool_id))
215+
]
216+
)
217+
)
218+
return ast.copy_location(new_return, node)
219+
220+
transformer = ReturnTransformer(node.name)
221+
node = transformer.visit(node)
222+
ast.fix_missing_locations(node)
223+
184224
func_code = ast.unparse(node)
185-
# 有些模型不支持name是中文例如: deepseek, 其他模型未知
225+
# 有些模型不支持name是中文,例如: deepseek, 其他模型未知
186226
escaped_desc = (name + ' ' + description).replace('\n', ' ').replace("'", " ")
187227
functions.append(f"@mcp.tool(description='{escaped_desc}')\n{func_code}\n")
188228
else:
189229
other_code.append(ast.unparse(node))
230+
190231
# 构建完整的 MCP 服务器代码
191232
code_parts = ["from mcp.server.fastmcp import FastMCP"]
192233
code_parts.extend(imports)
234+
code_parts.append(f"\nfrom pydantic import BaseModel")
235+
code_parts.append(f"\nfrom typing import Any")
236+
code_parts.append(f"\nclass Result(BaseModel):")
237+
code_parts.append(f"\n\tresult: Any")
238+
code_parts.append(f"\n\ttool_id: str\n")
193239
code_parts.append(f"\nmcp = FastMCP(\"{uuid.uuid7()}\")\n")
194240
code_parts.extend(other_code)
195241
code_parts.extend(functions)
196242
code_parts.append("\nmcp.run(transport=\"stdio\")\n")
197243
return "\n".join(code_parts)
198244

199-
def generate_mcp_server_code(self, code_str, params, name, description):
200-
code = self._generate_mcp_server_code(code_str, params, name, description)
245+
def generate_mcp_server_code(self, code_str, params, name, description, tool_id):
246+
code = self._generate_mcp_server_code(code_str, params, name, description, tool_id)
201247
set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
202248
return f"""
203249
import os, sys, logging
@@ -212,8 +258,8 @@ def generate_mcp_server_code(self, code_str, params, name, description):
212258
exec({dedent(code)!a})
213259
"""
214260

215-
def get_tool_mcp_config(self, code, params, name, description):
216-
_code = self.generate_mcp_server_code(code, params, name, description)
261+
def get_tool_mcp_config(self, tool, params):
262+
_code = self.generate_mcp_server_code(tool.code, params, tool.name, tool.desc, str(tool.id))
217263
maxkb_logger.debug(f"Python code of mcp tool: {_code}")
218264
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
219265
tool_config = {

0 commit comments

Comments
 (0)