Skip to content

Commit 67d0505

Browse files
authored
Merge pull request #162 from AnguseZhang/dev/zhouh
feat: enhance MatMasterAgent with job status monitoring, language callback, and code style unification
2 parents fe86de0 + 0fee951 commit 67d0505

7 files changed

Lines changed: 94 additions & 12 deletions

File tree

agents/matmaster_agent/agent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
from agents.matmaster_agent.MrDice_agent.agent import init_MrDice_agent
1515
from agents.matmaster_agent.apex_agent.agent import init_apex_agent
1616
from agents.matmaster_agent.base_agents.io_agent import HandleFileUploadLlmAgent
17-
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer
17+
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer, matmaster_set_lang, \
18+
matmaster_check_job_status
1819
from agents.matmaster_agent.chembrain_agent.agent import init_chembrain_agent
1920
from agents.matmaster_agent.constant import MATMASTER_AGENT_NAME
2021
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
@@ -76,8 +77,8 @@ def __init__(self, llm_config):
7677
global_instruction=GlobalInstruction,
7778
instruction=AgentInstruction,
7879
description=AgentDescription,
79-
before_agent_callback=matmaster_prepare_state,
80-
after_model_callback=matmaster_check_transfer,
80+
before_agent_callback=[matmaster_prepare_state, matmaster_set_lang],
81+
after_model_callback=[matmaster_check_job_status, matmaster_check_transfer],
8182
)
8283

8384
@override

agents/matmaster_agent/base_agents/callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ async def default_before_tool_callback(tool, args, tool_context):
118118

119119

120120
@check_None_wrapper
121-
def _get_ak(ctx: Union[InvocationContext, ToolContext]):
121+
def _get_ak(ctx: Union[InvocationContext, ToolContext, CallbackContext]):
122122
session_state = get_session_state(ctx)
123123
return session_state[FRONTEND_STATE_KEY]['biz'].get('ak') or os.getenv('BOHRIUM_ACCESS_KEY')
124124

agents/matmaster_agent/base_agents/job_agent.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
ModelRole,
3232
get_BohriumExecutor,
3333
get_BohriumStorage,
34-
get_DFlowExecutor,
34+
get_DFlowExecutor, OpenAPIJobAPI,
3535
)
3636
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
3737
from agents.matmaster_agent.model import BohrJobInfo, DFlowJobInfo
@@ -318,10 +318,14 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
318318
job_status = results['status']
319319
if not ctx.session.state['dflow']:
320320
bohr_job_id = results['extra_info']['bohr_job_id']
321+
job_query_url = f'{OpenAPIJobAPI}/{bohr_job_id}'
321322
job_detail_url = results['extra_info']['job_link']
322323
frontend_result = BohrJobInfo(origin_job_id=origin_job_id, job_name=job_name,
323-
job_status=job_status, job_detail_url=job_detail_url,
324-
job_id=bohr_job_id).model_dump(mode='json')
324+
job_status=job_status, job_query_url=job_query_url,
325+
job_detail_url=job_detail_url,
326+
job_id=bohr_job_id,
327+
agent_name=ctx.agent.parent_agent.parent_agent.name).model_dump(
328+
mode='json')
325329
else:
326330
workflow_id = results['extra_info']['workflow_id']
327331
workflow_uid = results['extra_info']['workflow_uid']
@@ -631,9 +635,13 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
631635
yield result_event
632636

633637
if (
634-
session_state[FRONTEND_STATE_KEY]['biz'].get('origin_id', None) is not None and
635-
list(session_state['long_running_jobs'].keys()) and
636-
session_state[FRONTEND_STATE_KEY]['biz']['origin_id'] in list(session_state['long_running_jobs'].keys())
638+
session_state.get('origin_job_id', None) is not None or
639+
(
640+
session_state[FRONTEND_STATE_KEY]['biz'].get('origin_id', None) is not None and
641+
list(session_state['long_running_jobs'].keys()) and
642+
session_state[FRONTEND_STATE_KEY]['biz']['origin_id'] in list(
643+
session_state['long_running_jobs'].keys())
644+
)
637645
): # Only Query Job Result
638646
pass
639647
else:

agents/matmaster_agent/callback.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66

77
import litellm
88
from google.adk.agents.callback_context import CallbackContext
9-
from google.adk.models import LlmResponse
9+
from google.adk.models import LlmResponse, LlmRequest
1010
from google.genai import types
11-
from google.genai.types import FunctionCall, Part
11+
from google.genai.types import FunctionCall, Part, FunctionResponse
1212

13+
from agents.matmaster_agent.base_agents.callback import _get_ak
1314
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY
1415
from agents.matmaster_agent.model import TransferCheck, UserContent
1516
from agents.matmaster_agent.prompt import get_transfer_check_prompt, get_user_content_lang
17+
from agents.matmaster_agent.utils.job_utils import get_job_status, has_job_running, get_running_jobs_detail
1618
from agents.matmaster_agent.utils.llm_response_utils import has_function_call
1719

1820
logger = logging.getLogger(__name__)
@@ -22,6 +24,7 @@
2224
async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional[types.Content]:
2325
callback_context.state['current_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
2426
callback_context.state['error_occurred'] = False
27+
callback_context.state['origin_job_id'] = None
2528

2629
callback_context.state[FRONTEND_STATE_KEY] = callback_context.state.get(FRONTEND_STATE_KEY, {})
2730
callback_context.state[FRONTEND_STATE_KEY]['biz'] = callback_context.state[FRONTEND_STATE_KEY].get('biz', {})
@@ -38,6 +41,8 @@ async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional
3841
callback_context.state['invocation_id_with_tool_call'] = callback_context.state.get('invocation_id_with_tool_call',
3942
None)
4043

44+
45+
async def matmaster_set_lang(callback_context: CallbackContext) -> Optional[types.Content]:
4146
user_content = callback_context.user_content.parts[0].text
4247
prompt = get_user_content_lang().format(user_content=user_content)
4348
response = litellm.completion(model='azure/gpt-4o', messages=[{'role': 'user', 'content': prompt}],
@@ -48,6 +53,37 @@ async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional
4853
callback_context.state['target_language'] = language
4954

5055

56+
async def matmaster_check_job_status(callback_context: CallbackContext, llm_response: LlmRequest) -> Optional[
57+
LlmResponse]:
58+
if (
59+
(jobs_dict := callback_context.state['long_running_jobs']) and
60+
has_job_running(jobs_dict)
61+
):
62+
running_job_ids = get_running_jobs_detail(jobs_dict)
63+
access_key = _get_ak(callback_context)
64+
if callback_context.state['target_language'] in ['Chinese']:
65+
job_complete_intro = '检测到任务 <{job_id}> 已完成,我将立刻转移至对应的 Agent 去获取任务结果,请稍等...'
66+
else:
67+
job_complete_intro = ('Job <{job_id}> has been detected as completed. '
68+
'I will immediately transfer to the corresponding agent to retrieve the job results. Please wait...')
69+
70+
for origin_job_id, job_id, job_query_url, agent_name in running_job_ids:
71+
job_status = get_job_status(job_query_url, access_key=access_key)
72+
if job_status in ['Failed', 'Finished']:
73+
logger.info(f"[matmaster_check_job_status] job_id = {job_id}, job_status = {job_status}")
74+
function_call_id = f"call_{str(uuid.uuid4()).replace('-', '')[:24]}"
75+
callback_context.state['origin_job_id'] = origin_job_id
76+
llm_response.content.parts.insert(0, Part(text=job_complete_intro.format(job_id=job_id),
77+
function_call=FunctionCall(id=function_call_id,
78+
name='transfer_to_agent',
79+
args={'agent_name': agent_name}),
80+
function_response=FunctionResponse(id=function_call_id,
81+
name='transfer_to_agent',
82+
response=None)
83+
)
84+
)
85+
86+
5187
# after_model_callback
5288
async def matmaster_check_transfer(callback_context: CallbackContext, llm_response: LlmResponse) -> Optional[
5389
LlmResponse]:

agents/matmaster_agent/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
DFLOW_HOST = 'https://workflows.deepmodeling.com'
6969
DFLOW_K8S_API_SERVER = 'https://workflows.deepmodeling.com'
7070
BOHRIUM_API_URL = 'https://bohrium-api.dp.tech'
71+
OpenAPIJobAPI = f"{OPENAPI_HOST}/openapi/v1/job"
7172

7273
DFlowExecutor = {
7374
'type': 'local',

agents/matmaster_agent/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ class JobResult(BaseModel):
2626
class BohrJobInfo(BaseModel):
2727
origin_job_id: str
2828
job_id: int
29+
job_query_url: str
2930
job_detail_url: str
3031
job_status: JobStatus
3132
job_name: str
3233
job_result: Optional[List[JobResult]] = None
3334
job_in_ctx: bool = False
35+
agent_name: str
3436

3537
@field_validator('job_detail_url')
3638
@classmethod
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import requests
2+
3+
from agents.matmaster_agent.model import JobStatus, BohrJobInfo
4+
5+
6+
def mapping_status(status):
7+
return {
8+
-1: 'Failed',
9+
0: 'Pending',
10+
1: 'Running',
11+
2: 'Finished',
12+
3: 'Scheduling',
13+
4: 'Stopping',
14+
5: 'Stopped',
15+
6: 'Terminating',
16+
7: 'Killing',
17+
8: 'Uploading',
18+
9: 'Wait'
19+
}.get(status, 'Unknown')
20+
21+
22+
def get_job_status(job_query_url, access_key):
23+
response = requests.request('GET', job_query_url, headers={'accessKey': access_key})
24+
return mapping_status(response.json()['data']['status'])
25+
26+
27+
def has_job_running(jobs_dict: BohrJobInfo) -> bool:
28+
"""检查是否有任何作业处于运行状态"""
29+
return any(job['job_status'] == JobStatus.Running for job in jobs_dict.values())
30+
31+
32+
def get_running_jobs_detail(jobs_dict: BohrJobInfo):
33+
return [(job['origin_job_id'], job['job_id'], job['job_query_url'], job['agent_name']) for job in jobs_dict.values()
34+
if job['job_status'] == JobStatus.Running]

0 commit comments

Comments
 (0)