Skip to content

Commit 45ee077

Browse files
committed
refactor: streamline matmaster_agent transfer logic by abstracting callback functions and modifying the agent model structure
1 parent 2e1b884 commit 45ee077

7 files changed

Lines changed: 89 additions & 61 deletions

File tree

agents/matmaster_agent/agent.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,18 @@
1414
from agents.matmaster_agent.MrDice_agent.agent import init_MrDice_agent
1515
from agents.matmaster_agent.apex_agent.agent import init_apex_agent
1616
from agents.matmaster_agent.base_agents.io_agent import HandleFileUploadLlmAgent
17-
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer, matmaster_set_lang, \
18-
matmaster_check_job_status
17+
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_set_lang, matmaster_check_job_status
1918
from agents.matmaster_agent.chembrain_agent.agent import init_chembrain_agent
2019
from agents.matmaster_agent.constant import MATMASTER_AGENT_NAME, ModelRole
2120
from agents.matmaster_agent.document_parser_agent.agent import init_document_parser_agent
2221
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
22+
from agents.matmaster_agent.model import MatMasterTargetAgentEnum
2323
from agents.matmaster_agent.organic_reaction_agent.agent import init_organic_reaction_agent
2424
from agents.matmaster_agent.perovskite_agent.agent import init_perovskite_agent
2525
from agents.matmaster_agent.piloteye_electro_agent.agent import init_piloteye_electro_agent
26-
from agents.matmaster_agent.prompt import AgentDescription, AgentInstruction, GlobalInstruction
26+
from agents.matmaster_agent.prompt import AgentDescription, AgentInstruction, GlobalInstruction, \
27+
MatMasterCheckTransferPrompt
28+
from agents.matmaster_agent.public.callback import check_transfer
2729
from agents.matmaster_agent.ssebrain_agent.agent import init_ssebrain_agent
2830
from agents.matmaster_agent.structure_generate_agent.agent import init_structure_generate_agent
2931
from agents.matmaster_agent.superconductor_agent.agent import init_superconductor_agent
@@ -82,8 +84,9 @@ def __init__(self, llm_config):
8284
instruction=AgentInstruction,
8385
description=AgentDescription,
8486
before_agent_callback=[matmaster_prepare_state, matmaster_set_lang],
85-
after_model_callback=[matmaster_check_job_status, matmaster_check_transfer],
86-
)
87+
after_model_callback=[matmaster_check_job_status,
88+
check_transfer(prompt=MatMasterCheckTransferPrompt,
89+
target_agent_enum=MatMasterTargetAgentEnum)])
8790

8891
@override
8992
async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]:

agents/matmaster_agent/callback.py

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212

1313
from agents.matmaster_agent.base_agents.callback import _get_ak
1414
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY
15-
from agents.matmaster_agent.model import TransferCheck, UserContent
16-
from agents.matmaster_agent.prompt import get_transfer_check_prompt, get_user_content_lang
15+
from agents.matmaster_agent.model import UserContent
16+
from agents.matmaster_agent.prompt import get_user_content_lang
1717
from agents.matmaster_agent.utils.job_utils import get_job_status, has_job_running, get_running_jobs_detail
18-
from agents.matmaster_agent.utils.llm_response_utils import has_function_call
1918

2019
logger = logging.getLogger(__name__)
2120

@@ -109,40 +108,3 @@ async def matmaster_check_job_status(callback_context: CallbackContext, llm_resp
109108

110109
callback_context.state['last_llm_response_partial'] = llm_response.partial
111110
return None
112-
113-
114-
async def matmaster_check_transfer(callback_context: CallbackContext, llm_response: LlmResponse) -> Optional[
115-
LlmResponse]:
116-
# 检查响应是否有效
117-
if not (
118-
llm_response and
119-
not llm_response.partial and
120-
llm_response.content and
121-
llm_response.content.parts and
122-
len(llm_response.content.parts) and
123-
llm_response.content.parts[0].text
124-
):
125-
return None
126-
127-
prompt = get_transfer_check_prompt().format(response_text=llm_response.content.parts[0].text)
128-
response = litellm.completion(model='azure/gpt-4o', messages=[{'role': 'user', 'content': prompt}],
129-
response_format=TransferCheck)
130-
131-
result: dict = json.loads(response.choices[0].message.content)
132-
is_transfer = bool(result.get('is_transfer', False))
133-
target_agent = str(result.get('target_agent', ''))
134-
reason = str(result.get('reason', ''))
135-
logger.info(f"[matmaster_check_transfer] target_agent = {target_agent}, is_transfer = {is_transfer}"
136-
f"response_text = {llm_response.content.parts[0].text}, reason = {reason}")
137-
if (
138-
is_transfer and
139-
not has_function_call(llm_response)
140-
):
141-
logger.warning(f"[matmaster_check_transfer] add `transfer_to_agent`")
142-
function_call_id = f"added_{str(uuid.uuid4()).replace('-', '')[:24]}"
143-
llm_response.content.parts.append(Part(function_call=FunctionCall(id=function_call_id, name='transfer_to_agent',
144-
args={'agent_name': target_agent})))
145-
146-
return llm_response
147-
148-
return None

agents/matmaster_agent/model.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ class DFlowJobInfo(BaseModel):
7171
job_in_ctx: bool = False
7272

7373

74-
class TargetAgentEnum(str, Enum):
74+
class ParamsCheckComplete(BaseModel):
75+
flag: bool
76+
reason: str
77+
analyzed_messages: List[str]
78+
79+
80+
class MatMasterTargetAgentEnum(str, Enum):
7581
ABACUSAgent = ABACUS_AGENT_NAME
7682
APEXAgent = ApexAgentName
7783
ChemBrainAgent = CHEMBRAIN_AGENT_NAME
@@ -91,17 +97,5 @@ class TargetAgentEnum(str, Enum):
9197
TrajAnalysisAgent = TrajAnalysisAgentName
9298

9399

94-
class ParamsCheckComplete(BaseModel):
95-
flag: bool
96-
reason: str
97-
analyzed_messages: List[str]
98-
99-
100-
class TransferCheck(BaseModel):
101-
is_transfer: bool
102-
target_agent: TargetAgentEnum
103-
reason: str
104-
105-
106100
class UserContent(BaseModel):
107101
language: str

agents/matmaster_agent/prompt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,10 +715,8 @@ def gen_params_check_info_agent_instruction():
715715
ResultCoreAgentDescription = 'Provides real-time task status updates and result forwarding to UI'
716716
TransferAgentDescription = 'Transfer to proper agent to answer user query'
717717

718-
719718
# LLM-Helper Prompt
720-
def get_transfer_check_prompt():
721-
return """
719+
MatMasterCheckTransferPrompt = """
722720
You are an expert judge tasked with evaluating whether the previous LLM's response contains a clear and explicit request or instruction to transfer the conversation to a specific agent (e.g., 'xxx agent').
723721
Analyze the provided RESPONSE TEXT to determine if it explicitly indicates a transfer action.
724722
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import json
2+
import logging
3+
import uuid
4+
from enum import Enum
5+
from typing import Optional, Callable, Type
6+
7+
import litellm
8+
from google.adk.agents.callback_context import CallbackContext
9+
from google.adk.agents.llm_agent import AfterModelCallback
10+
from google.adk.models import LlmResponse
11+
from google.genai.types import Part, FunctionCall
12+
13+
from agents.matmaster_agent.utils.llm_response_utils import has_function_call
14+
from agents.matmaster_agent.utils.model_utils import create_transfer_check_model
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def check_transfer(prompt: str, target_agent_enum: Type[Enum]) -> AfterModelCallback:
20+
async def wrapper(callback_context: CallbackContext, llm_response: LlmResponse) -> Optional[
21+
LlmResponse]:
22+
# 检查响应是否有效
23+
if not (
24+
llm_response and
25+
not llm_response.partial and
26+
llm_response.content and
27+
llm_response.content.parts and
28+
len(llm_response.content.parts) and
29+
llm_response.content.parts[0].text
30+
):
31+
return None
32+
33+
llm_prompt = prompt.format(response_text=llm_response.content.parts[0].text)
34+
response = litellm.completion(model='azure/gpt-4o', messages=[{'role': 'user', 'content': llm_prompt}],
35+
response_format=create_transfer_check_model(target_agent_enum))
36+
37+
result: dict = json.loads(response.choices[0].message.content)
38+
is_transfer = bool(result.get('is_transfer', False))
39+
target_agent = str(result.get('target_agent', ''))
40+
reason = str(result.get('reason', ''))
41+
symbol_name = f"[{callback_context.agent_name.replace('_agent', '')}_check_transfer]"
42+
logger.info(f"{symbol_name} target_agent = {target_agent}, is_transfer = {is_transfer}, "
43+
f"response_text = {llm_response.content.parts[0].text}, reason = {reason}")
44+
if (
45+
is_transfer and
46+
not has_function_call(llm_response)
47+
):
48+
logger.warning(f"{symbol_name} add `transfer_to_agent`")
49+
function_call_id = f"added_{str(uuid.uuid4()).replace('-', '')[:24]}"
50+
llm_response.content.parts.append(
51+
Part(function_call=FunctionCall(id=function_call_id, name='transfer_to_agent',
52+
args={'agent_name': target_agent})))
53+
54+
return llm_response
55+
56+
return None
57+
58+
return wrapper
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from pydantic import create_model, BaseModel
2+
3+
4+
def create_transfer_check_model(agent_type):
5+
"""动态创建具有特定 agent 类型的 TransferCheck 模型"""
6+
return create_model(
7+
'DynamicTransferCheck',
8+
is_transfer=(bool, ...),
9+
target_agent=(agent_type, ...),
10+
reason=(str, ...),
11+
__base__=BaseModel
12+
)

0 commit comments

Comments
 (0)