11import asyncio
22import json
3+ import logging
4+ import platform
35from copy import deepcopy
46from dataclasses import asdict
5- from typing import Dict , List , Literal , Optional , Union
7+ from typing import Any , Dict , List , Literal , Optional , Protocol , Union
68
79from tenacity import retry , retry_if_result , stop_after_attempt , wait_fixed
810
911from lagent .actions import AsyncActionExecutor
1012from lagent .hooks import Hook
1113from lagent .schema import ActionReturn , ActionStatusCode , ActionValidCode , AgentMessage , AgentStatusCode
14+ from lagent .skills .skills import SkillsLoader
1215from lagent .utils import create_object , truncate_text
1316from .agent import AsyncAgent
1417
18+ logger = logging .getLogger ("lagent.agents.fc_agent" )
19+
1520DEFAULT_TOOL_TEMPLATE = """# Tools
1621
1722You may call one or more functions to assist with the user query.
2732</tool_call>"""
2833
2934
30- def get_tool_prompt (actions : list , exclude_arguments : list = None , template : str = DEFAULT_TOOL_TEMPLATE ) -> str :
35+ def get_tool_prompt (
36+ actions : list , exclude_arguments : list = None , to_string : bool = True , template : str = DEFAULT_TOOL_TEMPLATE
37+ ) -> Union [str , List [dict ]]:
3138 exclude_arguments = exclude_arguments or ['session_id' ]
3239
3340 def _convert_tool_schema (action_description : dict , name_pattern : str = '{}' ) -> dict :
@@ -57,60 +64,138 @@ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') ->
5764 tools .append (_convert_tool_schema (api , f"{ action .name } .{{}}" ))
5865 else :
5966 tools .append (_convert_tool_schema (action_desc ))
67+ if not to_string :
68+ return tools
6069 return template .format (tools = '\n ' .join ([json .dumps (tool , ensure_ascii = False ) for tool in tools ]))
6170
6271
6372class FunctionCallAgent (AsyncAgent ):
6473 def __init__ (
6574 self ,
66- select_agent : Union [Dict , AsyncAgent ],
75+ policy_agent : Union [Dict , AsyncAgent ],
6776 env_agent : Union [Dict , AsyncAgent ],
68- finish_condition : callable = lambda x , _ : x and not x .tool_calls ,
77+ compact_agent : Optional [Dict ] = None ,
78+ consolidate_agent : Optional [Dict ] = None ,
79+ finish_condition : callable = lambda m , _ : m and not m .tool_calls ,
6980 max_turn : Optional [int ] = None ,
81+ initialize_input : bool = True ,
7082 name : Optional [str ] = None ,
7183 ):
7284 super ().__init__ (name = name )
73- self .select_agent = create_object (select_agent )
85+ self .policy_agent = create_object (policy_agent )
7486 self .env_agent = create_object (env_agent )
87+ self .compact_agent = create_object (compact_agent )
88+ self .consolidate_agent = create_object (consolidate_agent )
7589 self .finish_condition = finish_condition
7690 self .max_turn = max_turn
91+ self .initialize_input = initialize_input
7792
78- async def forward (self , env_message : AgentMessage , session_id : str | int , ** kwargs ):
79- selection_message : AgentMessage = None
93+ async def forward (self , env_message : AgentMessage , ** kwargs ):
94+ policy_message : AgentMessage = None
8095 current_turn = 0
81- while (self .finish_condition is None or not self .finish_condition (selection_message , env_message )) and (
96+ if self .initialize_input :
97+ env_message = await self .env_agent (env_message , ** kwargs )
98+
99+ while (self .finish_condition is None or not self .finish_condition (policy_message , env_message )) and (
82100 self .max_turn is None or current_turn < self .max_turn
83101 ):
84- selection_message = await self .select_agent (env_message , session_id = session_id , ** kwargs )
85- if selection_message .stream_state == AgentStatusCode .SERVER_ERR :
102+ policy_message = await self .policy_agent (env_message , ** kwargs )
103+ if policy_message .stream_state == AgentStatusCode .SERVER_ERR :
86104 raise ValueError ("Rollout response error: state is neither completed nor aborted!" )
87- if selection_message .stream_state == AgentStatusCode .SESSION_OUT_OF_LIMIT :
105+ if policy_message .stream_state == AgentStatusCode .SESSION_OUT_OF_LIMIT :
88106 for _ in range (2 ): # remove the last two messages
89- self .select_agent .memory . get ( session_id ) .delete (- 1 )
107+ self .policy_agent .memory .delete (- 1 )
90108 return AgentMessage (
91109 sender = self .name ,
92- content = 'Exceeded context length limit' ,
93- finish_reason = selection_message .finish_reason ,
110+ content = policy_message . content ,
111+ finish_reason = policy_message .finish_reason ,
94112 )
95- if selection_message .finish_reason == 'abort' :
96- return AgentMessage (sender = self .name , content = 'Aborted request' , finish_reason = 'abort' )
97- env_message = await self .env_agent (selection_message , session_id = session_id )
113+ if policy_message .finish_reason == 'abort' :
114+ return AgentMessage (sender = self .name , content = policy_message .content , finish_reason = 'abort' )
115+
116+ # Orchestrator manages memory
117+ await self ._maybe_manage_memory (policy_message , env_message )
118+
119+ env_message = await self .env_agent (policy_message )
98120 current_turn += 1
121+ if policy_message is not None :
122+ return AgentMessage (sender = self .name , content = policy_message .content , finish_reason = 'stop' )
99123 return AgentMessage (sender = self .name , content = "Finished" , finish_reason = 'stop' )
100124
125+ async def _maybe_manage_memory (self , policy_message : AgentMessage , env_message : AgentMessage ) -> None :
126+ """Orchestrate compact and consolidate.
127+
128+ Orchestrator calls policy's aggregator to get formatted_messages,
129+ checks should_compact, and if triggered:
130+ 1. Runs consolidate_agent (optional)
131+ 2. Runs compact_agent to produce summary
132+ 3. Injects summary + boundary into env_message
133+ ContextBuilder reads these on the next turn.
134+ """
135+ if not self .compact_agent :
136+ return
137+
138+ from lagent .agents .compact_agent import estimate_token_count
139+
140+ state = self .get_messages ()
141+ formatted_messages , tools = state ['policy_agent.messages' ], state ['policy_agent.tools' ]
142+ compact_input = AgentMessage (
143+ sender = self .name ,
144+ content = formatted_messages ,
145+ extra_info = {'context_tokens' : estimate_token_count (formatted_messages , tools )},
146+ )
147+ if not (hasattr (self .compact_agent , 'should_compact' ) and self .compact_agent .should_compact (compact_input )):
148+ return
149+
150+ # 1. Consolidate first (preserve info before compacting)
151+ if self .consolidate_agent :
152+ try :
153+ await self .consolidate_agent (compact_input )
154+ self .consolidate_agent .reset (recursive = True )
155+ logger .info ("Consolidation completed" )
156+ except Exception :
157+ logger .exception ("Consolidation failed, continuing with compact" )
158+ # 2. Compact — inject summary + boundary into env_message
159+ try :
160+ summary_msg = await self .compact_agent (compact_input )
161+ self .compact_agent .reset (recursive = True )
162+ if summary_msg and summary_msg .content :
163+ if env_message .env_info is None :
164+ env_message .env_info = {}
165+ env_message .env_info ['conversation_summary' ] = summary_msg .content
166+ env_message .env_info ['compact_boundary' ] = len (self .policy_agent .memory .memory )
167+ logger .info ("Compact summary injected (%d chars)" , len (summary_msg .content ))
168+ except Exception :
169+ logger .exception ("Compact failed" )
170+
171+
172+ class MemoryProvider (Protocol ):
173+ async def get_info (self ) -> dict :
174+ """Return long-term memory info for EnvAgent's env_info. The content and format are flexible, but should be concise."""
175+ ...
176+
101177
102178class EnvAgent (AsyncAgent ):
103179 def __init__ (
104180 self ,
105- actions : list ,
181+ actions ,
182+ skills : Optional [SkillsLoader ] = None ,
183+ long_term_memory : Optional [MemoryProvider ] = None ,
106184 stateful_tools : List [str ] = None ,
107185 max_tool_response_length : int = None ,
108186 tool_response_truncate_side : Literal ['left' , 'right' , 'middle' ] = 'middle' ,
109187 action_hooks : List [Union [dict , Hook ]] = None ,
110188 name : Optional [str ] = None ,
111189 ):
112190 super ().__init__ (name = name )
113- self .actions = AsyncActionExecutor (actions , hooks = action_hooks )
191+ if isinstance (actions , AsyncActionExecutor ):
192+ for action_hook in action_hooks or []:
193+ actions .register_hook (create_object (action_hook ))
194+ self .actions = actions
195+ else :
196+ self .actions = AsyncActionExecutor (actions , hooks = action_hooks )
197+ self .skills = create_object (skills )
198+ self .long_term_memory = create_object (long_term_memory )
114199 self .stateful_tools = stateful_tools or []
115200 self .max_tool_response_length = max_tool_response_length
116201 self .tool_response_truncate_side = tool_response_truncate_side
@@ -124,41 +209,56 @@ def __init__(
124209 retry_error_callback = lambda retry_state : retry_state .outcome .result (),
125210 )
126211
127- async def forward (self , selection_message : AgentMessage , session_id : str | int , ** kwargs ):
128- if not selection_message .tool_calls :
129- return AgentMessage (sender = self .name , content = 'No tool call' )
212+ async def get_env_info (self ) -> Dict [str , Any ]:
213+ env_info = {'skills' : '' , 'active_skills' : '' , 'memory' : '' , 'tools' : [], 'runtime' : {}}
214+ if self .skills is not None :
215+ env_info ['skills' ] = await self .skills .build_skills_summary ()
216+ always_skills = await self .skills .get_always_skills ()
217+ if always_skills :
218+ env_info ['active_skills' ] = await self .skills .load_skills_for_context (always_skills )
219+ if self .long_term_memory is not None :
220+ env_info ['memory' ] = await self .long_term_memory .get_info ()
221+ if self .actions :
222+ env_info ['tools' ] = get_tool_prompt (list (self .actions .actions .values ()), to_string = False )
223+ for name in ['system' , 'machine' , 'python_version' ]:
224+ env_info ['runtime' ][name ] = getattr (platform , name )()
225+ return env_info
226+
227+ async def forward (self , message : AgentMessage , ** kwargs ):
228+ if not message .tool_calls :
229+ return AgentMessage (sender = self .name , content = message .content , env_info = await self .get_env_info ())
130230
131231 tool_responses = await asyncio .gather (
132- * [
133- self ._retry_mechanism (self .execute_tool )(tool_call , session_id )
134- for tool_call in selection_message .tool_calls
135- ]
232+ * [self ._retry_mechanism (self .execute_tool )(tool_call ) for tool_call in message .tool_calls ]
136233 )
137- for tool_call_id , tool_response in zip (selection_message .tool_calls_ids , tool_responses ):
234+ for tool_call_id , tool_response in zip (message .tool_calls_ids , tool_responses ):
138235 tool_response .tool_call_id = tool_call_id
139236 res = tool_response .format_result ()
140237 if self .max_tool_response_length is not None and len (res ) > self .max_tool_response_length :
141238 res = truncate_text (res , max_num = self .max_tool_response_length , side = self .tool_response_truncate_side )
142239 tool_response .result = [{'type' : 'text' , 'content' : res }]
143- return AgentMessage (sender = self .name , content = [asdict (resp ) for resp in tool_responses ])
240+ return AgentMessage (
241+ sender = self .name , content = [asdict (resp ) for resp in tool_responses ], env_info = await self .get_env_info ()
242+ )
144243
145- async def execute_tool (self , tool_call : dict , session_id : str | int ) -> ActionReturn :
244+ async def execute_tool (self , tool_call : dict ) -> ActionReturn :
245+ tool_call = deepcopy (tool_call )
146246 try :
247+ if 'function' in tool_call :
248+ tool_call = tool_call ['function' ]
147249 if tool_call ['name' ].split ('.' , 1 )[0 ] not in self .actions :
148250 return ActionReturn (valid = ActionValidCode .INVALID , errmsg = f'Tool { tool_call ["name" ]} Not Found' )
149251 if isinstance (tool_call ['arguments' ], str ):
150252 tool_call ['arguments' ] = json .loads (tool_call ['arguments' ])
151253 if tool_call ['name' ] in self .stateful_tools :
152- tool_call = deepcopy (tool_call )
153- tool_call ['arguments' ]['session_id' ] = session_id
254+ tool_call ['arguments' ]['session_id' ] = str (id (self ))
154255 except Exception as e :
155256 return ActionReturn (valid = ActionValidCode .INVALID , errmsg = f'Invalid tool call format: { str (e )} ' )
156257 tool_response : ActionReturn = (
157258 await self .actions (
158259 AgentMessage (
159260 sender = 'assistant' , content = dict (name = tool_call ['name' ], parameters = tool_call ['arguments' ])
160261 ),
161- session_id = session_id ,
162262 )
163263 ).content
164264 return tool_response
0 commit comments