Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions agents/matmaster_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from agents.matmaster_agent.MrDice_agent.agent import init_MrDice_agent
from agents.matmaster_agent.apex_agent.agent import init_apex_agent
from agents.matmaster_agent.base_agents.io_agent import HandleFileUploadLlmAgent
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer
from agents.matmaster_agent.callback import matmaster_prepare_state, matmaster_check_transfer, matmaster_set_lang, \
matmaster_check_job_status
from agents.matmaster_agent.chembrain_agent.agent import init_chembrain_agent
from agents.matmaster_agent.constant import MATMASTER_AGENT_NAME
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
Expand Down Expand Up @@ -76,8 +77,8 @@ def __init__(self, llm_config):
global_instruction=GlobalInstruction,
instruction=AgentInstruction,
description=AgentDescription,
before_agent_callback=matmaster_prepare_state,
after_model_callback=matmaster_check_transfer,
before_agent_callback=[matmaster_prepare_state, matmaster_set_lang],
after_model_callback=[matmaster_check_job_status, matmaster_check_transfer],
)

@override
Expand Down
2 changes: 1 addition & 1 deletion agents/matmaster_agent/base_agents/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def default_before_tool_callback(tool, args, tool_context):


@check_None_wrapper
def _get_ak(ctx: Union[InvocationContext, ToolContext]):
def _get_ak(ctx: Union[InvocationContext, ToolContext, CallbackContext]):
session_state = get_session_state(ctx)
return session_state[FRONTEND_STATE_KEY]['biz'].get('ak') or os.getenv('BOHRIUM_ACCESS_KEY')

Expand Down
20 changes: 14 additions & 6 deletions agents/matmaster_agent/base_agents/job_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ModelRole,
get_BohriumExecutor,
get_BohriumStorage,
get_DFlowExecutor,
get_DFlowExecutor, OpenAPIJobAPI,
)
from agents.matmaster_agent.llm_config import MatMasterLlmConfig
from agents.matmaster_agent.model import BohrJobInfo, DFlowJobInfo
Expand Down Expand Up @@ -318,10 +318,14 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
job_status = results['status']
if not ctx.session.state['dflow']:
bohr_job_id = results['extra_info']['bohr_job_id']
job_query_url = f'{OpenAPIJobAPI}/{bohr_job_id}'
job_detail_url = results['extra_info']['job_link']
frontend_result = BohrJobInfo(origin_job_id=origin_job_id, job_name=job_name,
job_status=job_status, job_detail_url=job_detail_url,
job_id=bohr_job_id).model_dump(mode='json')
job_status=job_status, job_query_url=job_query_url,
job_detail_url=job_detail_url,
job_id=bohr_job_id,
agent_name=ctx.agent.parent_agent.parent_agent.name).model_dump(
mode='json')
else:
workflow_id = results['extra_info']['workflow_id']
workflow_uid = results['extra_info']['workflow_uid']
Expand Down Expand Up @@ -631,9 +635,13 @@ async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event,
yield result_event

if (
session_state[FRONTEND_STATE_KEY]['biz'].get('origin_id', None) is not None and
list(session_state['long_running_jobs'].keys()) and
session_state[FRONTEND_STATE_KEY]['biz']['origin_id'] in list(session_state['long_running_jobs'].keys())
session_state.get('origin_job_id', None) is not None or
(
session_state[FRONTEND_STATE_KEY]['biz'].get('origin_id', None) is not None and
list(session_state['long_running_jobs'].keys()) and
session_state[FRONTEND_STATE_KEY]['biz']['origin_id'] in list(
session_state['long_running_jobs'].keys())
)
): # Only Query Job Result
pass
else:
Expand Down
40 changes: 38 additions & 2 deletions agents/matmaster_agent/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@

import litellm
from google.adk.agents.callback_context import CallbackContext
from google.adk.models import LlmResponse
from google.adk.models import LlmResponse, LlmRequest
from google.genai import types
from google.genai.types import FunctionCall, Part
from google.genai.types import FunctionCall, Part, FunctionResponse

from agents.matmaster_agent.base_agents.callback import _get_ak
from agents.matmaster_agent.constant import FRONTEND_STATE_KEY
from agents.matmaster_agent.model import TransferCheck, UserContent
from agents.matmaster_agent.prompt import get_transfer_check_prompt, get_user_content_lang
from agents.matmaster_agent.utils.job_utils import get_job_status, has_job_running, get_running_jobs_detail
from agents.matmaster_agent.utils.llm_response_utils import has_function_call

logger = logging.getLogger(__name__)
Expand All @@ -22,6 +24,7 @@
async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional[types.Content]:
callback_context.state['current_time'] = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
callback_context.state['error_occurred'] = False
callback_context.state['origin_job_id'] = None

callback_context.state[FRONTEND_STATE_KEY] = callback_context.state.get(FRONTEND_STATE_KEY, {})
callback_context.state[FRONTEND_STATE_KEY]['biz'] = callback_context.state[FRONTEND_STATE_KEY].get('biz', {})
Expand All @@ -38,6 +41,8 @@ async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional
callback_context.state['invocation_id_with_tool_call'] = callback_context.state.get('invocation_id_with_tool_call',
None)


async def matmaster_set_lang(callback_context: CallbackContext) -> Optional[types.Content]:
user_content = callback_context.user_content.parts[0].text
prompt = get_user_content_lang().format(user_content=user_content)
response = litellm.completion(model='azure/gpt-4o', messages=[{'role': 'user', 'content': prompt}],
Expand All @@ -48,6 +53,37 @@ async def matmaster_prepare_state(callback_context: CallbackContext) -> Optional
callback_context.state['target_language'] = language


async def matmaster_check_job_status(callback_context: CallbackContext, llm_response: LlmRequest) -> Optional[
LlmResponse]:
if (
(jobs_dict := callback_context.state['long_running_jobs']) and
has_job_running(jobs_dict)
):
running_job_ids = get_running_jobs_detail(jobs_dict)
access_key = _get_ak(callback_context)
if callback_context.state['target_language'] in ['Chinese']:
job_complete_intro = '检测到任务 <{job_id}> 已完成,我将立刻转移至对应的 Agent 去获取任务结果,请稍等...'
else:
job_complete_intro = ('Job <{job_id}> has been detected as completed. '
'I will immediately transfer to the corresponding agent to retrieve the job results. Please wait...')

for origin_job_id, job_id, job_query_url, agent_name in running_job_ids:
job_status = get_job_status(job_query_url, access_key=access_key)
if job_status in ['Failed', 'Finished']:
logger.info(f"[matmaster_check_job_status] job_id = {job_id}, job_status = {job_status}")
function_call_id = f"call_{str(uuid.uuid4()).replace('-', '')[:24]}"
callback_context.state['origin_job_id'] = origin_job_id
llm_response.content.parts.insert(0, Part(text=job_complete_intro.format(job_id=job_id),
function_call=FunctionCall(id=function_call_id,
name='transfer_to_agent',
args={'agent_name': agent_name}),
function_response=FunctionResponse(id=function_call_id,
name='transfer_to_agent',
response=None)
)
)


# after_model_callback
async def matmaster_check_transfer(callback_context: CallbackContext, llm_response: LlmResponse) -> Optional[
LlmResponse]:
Expand Down
1 change: 1 addition & 0 deletions agents/matmaster_agent/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
DFLOW_HOST = 'https://workflows.deepmodeling.com'
DFLOW_K8S_API_SERVER = 'https://workflows.deepmodeling.com'
BOHRIUM_API_URL = 'https://bohrium-api.dp.tech'
OpenAPIJobAPI = f"{OPENAPI_HOST}/openapi/v1/job"

DFlowExecutor = {
'type': 'local',
Expand Down
2 changes: 2 additions & 0 deletions agents/matmaster_agent/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ class JobResult(BaseModel):
class BohrJobInfo(BaseModel):
origin_job_id: str
job_id: int
job_query_url: str
job_detail_url: str
job_status: JobStatus
job_name: str
job_result: Optional[List[JobResult]] = None
job_in_ctx: bool = False
agent_name: str

@field_validator('job_detail_url')
@classmethod
Expand Down
34 changes: 34 additions & 0 deletions agents/matmaster_agent/utils/job_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import requests

from agents.matmaster_agent.model import JobStatus, BohrJobInfo


def mapping_status(status):
return {
-1: 'Failed',
0: 'Pending',
1: 'Running',
2: 'Finished',
3: 'Scheduling',
4: 'Stopping',
5: 'Stopped',
6: 'Terminating',
7: 'Killing',
8: 'Uploading',
9: 'Wait'
}.get(status, 'Unknown')


def get_job_status(job_query_url, access_key):
response = requests.request('GET', job_query_url, headers={'accessKey': access_key})
return mapping_status(response.json()['data']['status'])


def has_job_running(jobs_dict: BohrJobInfo) -> bool:
"""检查是否有任何作业处于运行状态"""
return any(job['job_status'] == JobStatus.Running for job in jobs_dict.values())


def get_running_jobs_detail(jobs_dict: BohrJobInfo):
return [(job['origin_job_id'], job['job_id'], job['job_query_url'], job['agent_name']) for job in jobs_dict.values()
if job['job_status'] == JobStatus.Running]
Loading