Skip to content

Commit d337758

Browse files
authored
Merge pull request #168 from AnguseZhang/dev/zhouh
refactor: streamline agent transition handling and enhance session tr…
2 parents 4dcc2d9 + d4a7879 commit d337758

4 files changed

Lines changed: 51 additions & 9 deletions

File tree

agents/matmaster_agent/agent.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,8 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
9090
# 对于 [matmaster_check_job_status] 生成的消息, 手动拼一个流式消息
9191
if ctx.session.state['special_llm_response']:
9292
yield frontend_text_event(ctx, self.name, event.content.parts[0].text, ModelRole)
93-
94-
ctx.session.state['special_llm_response'] = False
95-
await update_session_state(ctx, self.name)
93+
ctx.session.state['special_llm_response'] = False
94+
await update_session_state(ctx, self.name)
9695
yield event
9796
except BaseException as err:
9897
async for error_event in send_error_event(err, ctx, self.name):

agents/matmaster_agent/callback.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ async def matmaster_check_job_status(callback_context: CallbackContext, llm_resp
6464
running_job_ids = get_running_jobs_detail(jobs_dict)
6565
access_key = _get_ak(callback_context)
6666
if callback_context.state['target_language'] in ['Chinese', 'zh-CN', '简体中文', 'Chinese (Simplified)']:
67-
job_complete_intro = '检测到任务 <{job_id}> 已完成,我将立刻转移至对应的 Agent 去获取任务结果,请稍等...'
67+
job_complete_intro = '检测到任务 <{job_id}> 已完成,我将立刻转移至对应的 Agent 去获取任务结果'
6868
else:
6969
job_complete_intro = ('Job <{job_id}> has been detected as completed. '
70-
'I will immediately transfer to the corresponding agent to retrieve the job results. Please wait...')
70+
'I will immediately transfer to the corresponding agent to retrieve the job results.')
7171

7272
reset = False
7373
for origin_job_id, job_id, job_query_url, agent_name in running_job_ids:
@@ -116,7 +116,7 @@ async def matmaster_check_transfer(callback_context: CallbackContext, llm_respon
116116
is_transfer and
117117
not has_function_call(llm_response)
118118
):
119-
logger.warning(f"Detected Agent Transfer Hallucination, add `transfer_to_agent`")
119+
logger.warning(f"[matmaster_check_transfer] target_agent = {target_agent}")
120120
function_call_id = f"call_{str(uuid.uuid4()).replace('-', '')[:24]}"
121121
llm_response.content.parts.append(Part(function_call=FunctionCall(id=function_call_id, name='transfer_to_agent',
122122
args={'agent_name': target_agent})))

agents/matmaster_agent/model.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,23 @@
33

44
from pydantic import BaseModel, field_validator
55

6+
from agents.matmaster_agent.ABACUS_agent.constant import ABACUS_AGENT_NAME
7+
from agents.matmaster_agent.DPACalculator_agent.constant import DPACalulator_AGENT_NAME
8+
from agents.matmaster_agent.HEACalculator_agent.constant import HEACALCULATOR_AGENT_NAME
9+
from agents.matmaster_agent.HEA_assistant_agent.constant import HEA_assistant_AgentName
10+
from agents.matmaster_agent.INVAR_agent.constant import INVAR_AGENT_NAME
11+
from agents.matmaster_agent.MrDice_agent.constant import MrDice_Agent_Name
12+
from agents.matmaster_agent.apex_agent.constant import ApexAgentName
13+
from agents.matmaster_agent.chembrain_agent.constant import CHEMBRAIN_AGENT_NAME
14+
from agents.matmaster_agent.organic_reaction_agent.constant import ORGANIC_REACTION_AGENT_NAME
15+
from agents.matmaster_agent.perovskite_agent.constant import PerovskiteAgentName
16+
from agents.matmaster_agent.piloteye_electro_agent.constant import PILOTEYE_ELECTRO_AGENT_NAME
17+
from agents.matmaster_agent.ssebrain_agent.constant import SSEBRAIN_AGENT_NAME
18+
from agents.matmaster_agent.structure_generate_agent.constant import StructureGenerateAgentName
19+
from agents.matmaster_agent.superconductor_agent.constant import SuperconductorAgentName
20+
from agents.matmaster_agent.thermoelectric_agent.constant import ThermoelectricAgentName
21+
from agents.matmaster_agent.traj_analysis_agent.constant import TrajAnalysisAgentName
22+
623

724
class JobStatus(str, Enum):
825
Running = 'Running'
@@ -53,9 +70,28 @@ class DFlowJobInfo(BaseModel):
5370
job_in_ctx: bool = False
5471

5572

73+
class TargetAgentEnum(str, Enum):
74+
ABACUSAgent = ABACUS_AGENT_NAME
75+
APEXAgent = ApexAgentName
76+
ChemBrainAgent = CHEMBRAIN_AGENT_NAME
77+
DPACalculatorAgent = DPACalulator_AGENT_NAME
78+
HEAAssistantAgent = HEA_assistant_AgentName
79+
HEACalculatorAgent = HEACALCULATOR_AGENT_NAME
80+
INVARAgent = INVAR_AGENT_NAME
81+
MrDiceAgent = MrDice_Agent_Name
82+
OrganicReactionAgent = ORGANIC_REACTION_AGENT_NAME
83+
PerovskiteAgent = PerovskiteAgentName
84+
PiloteyeElectroAgent = PILOTEYE_ELECTRO_AGENT_NAME
85+
SSEBrainAgent = SSEBRAIN_AGENT_NAME
86+
StructureGenerateAgent = StructureGenerateAgentName
87+
SuperConductorAgent = SuperconductorAgentName
88+
ThermoElectricAgent = ThermoelectricAgentName
89+
TrajAnalysisAgent = TrajAnalysisAgentName
90+
91+
5692
class TransferCheck(BaseModel):
5793
is_transfer: bool
58-
target_agent: str
94+
target_agent: TargetAgentEnum
5995

6096

6197
class UserContent(BaseModel):

agents/matmaster_agent/utils/helper_func.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import copy
2+
import inspect
23
import json
34
import logging
5+
import os
46
from typing import List, Union
57

68
from google.adk.agents.invocation_context import InvocationContext
@@ -19,9 +21,14 @@ def get_session_state(ctx: Union[InvocationContext, ToolContext]):
1921

2022

2123
async def update_session_state(ctx: InvocationContext, author: str):
24+
stack = inspect.stack()
25+
frame = stack[1] # stack[1] 表示调用当前函数的上一层调用
26+
filename = os.path.basename(frame.filename)
27+
lineno = frame.lineno
2228
actions_with_update = EventActions(state_delta=ctx.session.state)
23-
system_event = Event(invocation_id=ctx.invocation_id, author=author, actions=actions_with_update)
24-
await ctx.session_service.append_event(ctx.session, system_event)
29+
system_event = Event(invocation_id=ctx.invocation_id, author=f"{filename}:{lineno}",
30+
actions=actions_with_update)
31+
await ctx.session_service.append_event(ctx.session, system_event) # 会引入一个空消息
2532

2633

2734
def update_llm_response(llm_response: LlmResponse, current_function_calls: List[dict],

0 commit comments

Comments
 (0)