|
13 | 13 | import threading |
14 | 14 | from functools import reduce |
15 | 15 | from typing import Iterator |
16 | | -from maxkb.const import CONFIG |
| 16 | + |
17 | 17 | from django.http import StreamingHttpResponse |
18 | 18 | from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk |
19 | 19 | from langchain_mcp_adapters.client import MultiServerMCPClient |
20 | 20 | from langgraph.prebuilt import create_react_agent |
| 21 | + |
21 | 22 | from application.flow.i_step_node import WorkFlowPostHandler |
22 | 23 | from common.result import result |
23 | 24 | from common.utils.logger import maxkb_logger |
| 25 | +from maxkb.const import CONFIG |
24 | 26 |
|
25 | 27 |
|
26 | 28 | class Reasoning: |
@@ -104,7 +106,7 @@ def get_reasoning_content(self, chunk): |
104 | 106 | if reasoning_content_end_tag_index > -1: |
105 | 107 | reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_index] |
106 | 108 | content_chunk = self.reasoning_content_chunk[ |
107 | | - reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] |
| 109 | + reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:] |
108 | 110 | self.reasoning_content += reasoning_content_chunk |
109 | 111 | self.content += content_chunk |
110 | 112 | self.reasoning_content_chunk = "" |
@@ -314,7 +316,7 @@ def _extract_tool_id(raw_id): |
314 | 316 | return tool_id or raw_id |
315 | 317 |
|
316 | 318 |
|
317 | | -async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True): |
| 319 | +async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params=None):
318 | 320 | try: |
319 | 321 | client = MultiServerMCPClient(json.loads(mcp_servers)) |
320 | 322 | tools = await client.get_tools() |
@@ -362,11 +364,20 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_ |
362 | 364 | # 尝试解析 JSON,判断是否完整 |
363 | 365 | if entry['id'] and entry['arguments']: |
364 | 366 | try: |
365 | | - json.loads(entry['arguments']) |
| 367 | + parsed_args = json.loads(entry['arguments']) |
| 368 | + # 过滤掉 tool_init_params 中的参数 |
| 369 | + if tool_init_params: |
| 370 | + filtered_args = { |
| 371 | + k: v for k, v in parsed_args.items() |
| 372 | + if k not in tool_init_params |
| 373 | + } |
| 374 | + else: |
| 375 | + filtered_args = parsed_args |
| 376 | + |
366 | 377 | # JSON 完整,保存到 tool_calls_info |
367 | 378 | tool_calls_info[entry['id']] = { |
368 | 379 | 'name': entry['name'], |
369 | | - 'input': entry['arguments'] |
| 380 | + 'input': json.dumps(filtered_args, ensure_ascii=False) |
370 | 381 | } |
371 | 382 | # 从 fragments 中移除 |
372 | 383 | del _tool_fragments[idx] |
@@ -410,14 +421,14 @@ def get_real_error(exc): |
410 | 421 | raise RuntimeError(error_msg) from None |
411 | 422 |
|
412 | 423 |
|
413 | | -def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True): |
| 424 | +def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params=None):
414 | 425 | """使用全局事件循环,不创建新实例""" |
415 | 426 | result_queue = queue.Queue() |
416 | 427 | loop = get_global_loop() # 使用共享循环 |
417 | 428 |
|
418 | 429 | async def _run(): |
419 | 430 | try: |
420 | | - async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) |
| 431 | + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params) |
421 | 432 | async for chunk in async_gen: |
422 | 433 | result_queue.put(('data', chunk)) |
423 | 434 | except Exception as e: |
|
0 commit comments