Skip to content

Commit 4dcc2d9

Browse files
authored
Merge pull request #167 from AnguseZhang/dev/zhouh
feat: integrate session state update and frontend text event handling…
2 parents 4b0cc1c + fd27f82 commit 4dcc2d9

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

agents/matmaster_agent/agent.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer, matmaster_set_lang, \
1818
matmaster_check_job_status
1919
from agents.matmaster_agent.chembrain_agent.agent import init_chembrain_agent
20-
from agents.matmaster_agent.constant import MATMASTER_AGENT_NAME
20+
from agents.matmaster_agent.constant import MATMASTER_AGENT_NAME, ModelRole
2121
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
2222
from agents.matmaster_agent.organic_reaction_agent.agent import init_organic_reaction_agent
2323
from agents.matmaster_agent.perovskite_agent.agent import init_perovskite_agent
@@ -28,7 +28,8 @@
2828
from agents.matmaster_agent.superconductor_agent.agent import init_superconductor_agent
2929
from agents.matmaster_agent.thermoelectric_agent.agent import init_thermoelectric_agent
3030
from agents.matmaster_agent.traj_analysis_agent.agent import init_traj_analysis_agent
31-
from agents.matmaster_agent.utils.event_utils import send_error_event
31+
from agents.matmaster_agent.utils.event_utils import send_error_event, frontend_text_event
32+
from agents.matmaster_agent.utils.helper_func import update_session_state
3233

3334
logging.getLogger('google_adk.google.adk.tools.base_authenticated_tool').setLevel(logging.ERROR)
3435

@@ -86,6 +87,12 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
8687
try:
8788
# Delegate to parent implementation for the actual processing
8889
async for event in super()._run_async_impl(ctx):
90+
# 对于 [matmaster_check_job_status] 生成的消息, 手动拼一个流式消息
91+
if ctx.session.state['special_llm_response']:
92+
yield frontend_text_event(ctx, self.name, event.content.parts[0].text, ModelRole)
93+
94+
ctx.session.state['special_llm_response'] = False
95+
await update_session_state(ctx, self.name)
8996
yield event
9097
except BaseException as err:
9198
async for error_event in send_error_event(err, ctx, self.name):

agents/matmaster_agent/callback.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from google.adk.agents.callback_context import CallbackContext
99
from google.adk.models import LlmResponse
1010
from google.genai import types
11-
from google.genai.types import FunctionCall, Part, FunctionResponse
11+
from google.genai.types import FunctionCall, Part
1212

1313
from agents.matmaster_agent.base_agents.callback import _get_ak
1414
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY
@@ -40,6 +40,7 @@ async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional
4040
callback_context.state['sync_tools'] = callback_context.state.get('sync_tools', None)
4141
callback_context.state['invocation_id_with_tool_call'] = callback_context.state.get('invocation_id_with_tool_call',
4242
None)
43+
callback_context.state['special_llm_response'] = False
4344

4445

4546
async def matmaster_set_lang(callback_context: CallbackContext) -> Optional[types.Content]:
@@ -72,7 +73,11 @@ async def matmaster_check_job_status(callback_context: CallbackContext, llm_resp
7273
for origin_job_id, job_id, job_query_url, agent_name in running_job_ids:
7374
job_status = get_job_status(job_query_url, access_key=access_key)
7475
if job_status in ['Failed', 'Finished']:
76+
if llm_response.partial: # 原来消息的流式版本置空 None
77+
llm_response.content = None
78+
break
7579
if not reset:
80+
callback_context.state['special_llm_response'] = True # 标记开始处理原来消息的非流式版本
7681
llm_response.content.parts = []
7782
reset = True
7883
logger.info(f"[matmaster_check_job_status] job_id = {job_id}, job_status = {job_status}")

0 commit comments

Comments
 (0)