Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

"""


def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str,
reasoning_content: str):
chat_model = node_variable.get('chat_model')
Expand Down Expand Up @@ -102,7 +103,6 @@ def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INo
_write_context(node_variable, workflow_variable, node, workflow, answer, reasoning_content)



async def _yield_mcp_response(chat_model, message_list, mcp_servers):
async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
agent = create_react_agent(chat_model, client.get_tools())
Expand All @@ -115,6 +115,7 @@ async def _yield_mcp_response(chat_model, message_list, mcp_servers):
if isinstance(chunk[0], AIMessageChunk):
yield chunk[0]


def mcp_response_generator(chat_model, message_list, mcp_servers):
loop = asyncio.new_event_loop()
try:
Expand All @@ -130,6 +131,7 @@ def mcp_response_generator(chat_model, message_list, mcp_servers):
finally:
loop.close()


async def anext_async(agen):
return await agen.__anext__()

Expand Down Expand Up @@ -186,7 +188,8 @@ def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.context['reasoning_content'] = details.get('reasoning_content')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

  1. Docstring Improvements: The docstrings could be more detailed, especially explaining what each function is intended to do.

  2. Variable Naming Consistency: Use descriptive variable names instead of node_variable, workflow_variable, etc., which can make the code easier to understand.

  3. Function Annotations: Add annotations to functions to clarify their parameters and return types (where applicable).

  4. Asynchronous Functions: Ensure that all asynchronous operations are properly awaited or returned in coroutines for better readability.

Here’s the modified version with some improvements:

from typing import Any, AsyncGenerator, Dict, List, Optional
import asyncio
import json
from langchain_core.messages import AIMessageChunk
from langchain.agents.react import BaseAgent
from langchain.agents.agent_toolkits.multipleservers.mcp_client import MultiServerMCPClient
from langchain.chat_models.base import BaseChatModel
from custom_chat_model import create_react_agent  # Assuming this module exists
from dataclasses import dataclass

@dataclass
class Context:
    context: dict
    answer_text: str

async def yield_mcp_response(chat_model: BaseChatModel, message_list: List[str], mcp_servers: List[str]) -> AsyncGenerator[AIMessageChunk, None]:
    async with MultiServerMCPClient(json.loads(mcp_servers)) as client:
        agent = create_react_agent(chat_model, client.get_tools())
        response_stream = await agent.call_with_retry(message_list)
        for chunk in response_stream:
            if isinstance(chunk[0], AIMessageChunk):
                yield chunk[0]

def mcp_response_generator(
    chat_model: BaseChatModel,
    messages: List[str],
    mcp_servers: List[str]
) -> None:
    loop = asyncio.new_event_loop()
    try:
        # Create the async generator once, then drive it to exhaustion
        # via the event loop (await is not allowed in a sync function).
        response_gen = yield_mcp_response(chat_model, messages, mcp_servers)
        while True:
            next_chunk = loop.run_until_complete(anext_async(response_gen))
            print(next_chunk.content)
    except StopAsyncIteration:
        pass
    finally:
        loop.close()

context = None

async def save_context(details: Dict, workflow_manager: Any) -> None:
    global context
    context = Context(context=details, answer_text=None)

async def execute(model_id: int, system: str, prompt: str, dialogue_number: int, history_chat_record: Any, stream: bool, chat_id: int, chat_record_id: int,
                  model_params_setting: Any = None, node_params: Optional[Dict[str, object]] = None) -> str:
    global context
    if node_params and node_params.get('is_result', False):
        context.answer_text = context.context.get('answer')

# Example usage:
# result = await save_context({'answer': 'Hello'}, workflow_manager)

Key Changes:

  • Added type hints for improved clarity.
  • Used AsyncGenerator for yielding chunks from MCP server responses.
  • Simplified example usage at the bottom.
  • Ensured proper handling of asynchronous generator within mcp_response_generator.

These changes should improve the readability and maintainability of the provided code.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def save_context(self, details, workflow_manage):
self.context['question'] = details.get('question')
self.context['type'] = details.get('type')
self.context['reasoning_content'] = details.get('reasoning_content')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type,
app_document_list=None, app_image_list=None, app_audio_list=None, child_node=None, node_data=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
class BaseReplyNode(IReplyNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
if reply_type == 'referencing':
result = self.get_reference_content(fields)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def save_context(self, details, workflow_manage):
self.context['start_time'] = details.get('start_time')
self.context['form_data'] = form_data
self.context['is_submit'] = details.get('is_submit')
self.answer_text = details.get('result')
if self.node_params.get('is_result', False):
self.answer_text = details.get('result')
if form_data is not None:
for key in form_data:
self.context[key] = form_data[key]
Expand Down Expand Up @@ -70,7 +71,7 @@ def get_answer_list(self) -> List[Answer] | None:
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
'form_data': self.context.get('form_data', {}),
"is_submit": self.context.get("is_submit", False)}
form = f'<form_rander>{json.dumps(form_setting,ensure_ascii=False)}</form_rander>'
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
context = self.workflow_manage.get_workflow_content()
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
Expand All @@ -85,7 +86,7 @@ def get_details(self, index: int, **kwargs):
"chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
'form_data': self.context.get('form_data', {}),
"is_submit": self.context.get("is_submit", False)}
form = f'<form_rander>{json.dumps(form_setting,ensure_ascii=False)}</form_rander>'
form = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'
context = self.workflow_manage.get_workflow_content()
form_content_format = self.workflow_manage.reset_prompt(form_content_format)
prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided Python code has some minor improvements and corrections. Here are the main points to address:

  1. In the save_context method, remove the unnecessary indentation for setting self.node_params.get('is_result', False) inside the if self.node_params.get... line.

  2. Ensure that the variable names are consistent throughout the function.

  3. Clean up the get_answer_list and get_details methods by removing the duplicated code where the form is created with JSON data wrapping it.

Here's the optimized code snippet:

def save_context(self, details, workflow_manage):
    # Add more context items here.
    self.context['start_time'] = details.get('start_time')
    self.context['form_data'] = details.get('form_data')
    self.context['is_submit'] = details.get('is_submit')
    if self.node_params.get('is_result', False):
        self.answer_text = details.get('result')

def get_answer_list(self) -> List[Answer] | None:
    if not self.flow_params_serializer or not self.flow_params_serializer.data:
        return []  # Check flow_params_serializer before accessing its fields

    form_setting = {
        "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
        "form_data": self.context.get('form_data', {}),
        "is_submit": self.context.get("is_submit", False)
    }
    
    form_rander = f'<form_rander>{json.dumps(form_setting, ensure_ascii=False)}</form_rander>'

    context = self.workflow_manage.get_workflow_content()
    form_content_format = self.workflow_manage.reset_prompt(form_content_format)

    prompt_template = PromptTemplate.from_template(
        form_content_format,
        template_format='jinja2'
    )

    # Return list of answers here using the prompt_template and context.

def get_details(self, index: int, **kwargs):
    # Similarly optimize this method to avoid duplication of logic in creating 'form' variable.

These changes should help make the code cleaner and potentially more efficient, while ensuring consistency across function definitions and eliminating redundancy.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def valid_function(function_lib, user_id):
class BaseFunctionLibNodeNode(IFunctionLibNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
self.answer_text = str(details.get('result'))
if self.node_params.get('is_result'):
self.answer_text = str(details.get('result'))

def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult:
function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def convert_value(name: str, value, _type, is_required, source, node):
class BaseFunctionNodeNode(IFunctionNode):
def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
self.answer_text = str(details.get('result'))
if self.node_params.get('is_result', False):
self.answer_text = str(details.get('result'))

def execute(self, input_field_list, code, **kwargs) -> NodeResult:
params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ class BaseImageGenerateNode(IImageGenerateNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
model_params_setting,
chat_record_id,
**kwargs) -> NodeResult:
print(model_params_setting)
application = self.workflow_manage.work_flow_post_handler.chat_info.application
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The provided code snippet has several improvements that can be made for clarity, efficiency, and better readability:

Improvements Summary:

  1. Removed Redundant Code: The code was unnecessarily adding self.answer_text regardless of the value of is_result. Removed this unnecessary assignment since it can cause confusion.

  2. Consistency in Parameter Access: Ensure consistent access to parameters within methods and functions.

  3. Print Statement Adjustment: Move printing model_params_setting outside of conditional statements to ensure consistency.

  4. Optimization: Minor adjustments to improve code structure.

Here is the revised version of the code with these changes applied:

@@ -16,18 +15,23 @@
class BaseImageGenerateNode(IImageGenerateNode):
    def save_context(self, details, workflow_manage):
        self.context['answer'] = details.get('answer', None)
        self.context['question'] = details.get('question', None)
+        answer_to_save = details.get('answer')
+        if not isinstance(answer_to_save, str):
+            answer_to_save = ''
+        if self.node_params.get('is_result', False):
+            self.answer_text = answer_to_save

    def execute(
        self,
        model_id: int,
        prompt: str,
        negative_prompt: Optional[str],
        dialogue_number: int,
        dialogue_type: str,
        history_chat_record: List[str],
        chat_id: str,
        model_params_setting: Optional[Dict[Any, Any]] = None,
        chat_record_id: Optional[int] = None,
        **kwargs
    ) -> NodeResult:
        application = self.workflow_manage.work_flow_post_handler.chat_info.application
        tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)

        # Print statement should ideally have no side effects unless necessary.
        print(f"Model Parameters Setting: {model_params_setting}")

These modifications make the code cleaner and more understandable. Adjustments like converting variables to strings where appropriate help maintain type safety without causing runtime exceptions. Printing statements related to input data (model_params_setting) are also adjusted for clarity.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['question'] = details.get('question')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
model_params_setting,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def save_context(self, details, workflow_manage):
self.context['result'] = details.get('result')
self.context['tool_params'] = details.get('tool_params')
self.context['mcp_tool'] = details.get('mcp_tool')
self.answer_text = details.get('result')
if self.node_params.get('is_result', False):
self.answer_text = details.get('result')

def execute(self, mcp_servers, mcp_server, mcp_tool, tool_params, **kwargs) -> NodeResult:
servers = json.loads(mcp_servers)
Expand All @@ -27,7 +28,8 @@ async def call_tool(s, session, t, a):
return s

res = asyncio.run(call_tool(servers, mcp_server, mcp_tool, params))
return NodeResult({'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {})
return NodeResult(
{'result': [content.text for content in res.content], 'tool_params': params, 'mcp_tool': mcp_tool}, {})

def handle_variables(self, tool_params):
# 处理参数中的变量
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.context['message_tokens'] = details.get('message_tokens')
self.context['answer_tokens'] = details.get('answer_tokens')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
model_params_setting=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class BaseSpeechToTextNode(ISpeechToTextNode):

def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
class BaseTextToSpeechNode(ITextToSpeechNode):
def save_context(self, details, workflow_manage):
self.context['answer'] = details.get('answer')
self.answer_text = details.get('answer')
if self.node_params.get('is_result', False):
self.answer_text = details.get('answer')

def execute(self, tts_model_id, chat_id,
content, model_params_setting=None,
Expand Down