Skip to content

Commit afe13d7

Browse files
Improve FunctionCallAgent (#316)
* remove `AsyncPolicyAgent`
* fix type
* improve `FunctionCallAgent`
1 parent 0553aaf commit afe13d7

14 files changed

Lines changed: 458 additions & 348 deletions

File tree

lagent/agents/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
StreamingSequential,
1212
)
1313
from .compact_agent import AsyncCompactAgent, estimate_token_count
14-
from .internclaw_agent import (
15-
AsyncEnvAgent,
16-
AsyncPolicyAgent,
17-
InternClawAgent,
18-
)
14+
from .internclaw_agent import AsyncEnvAgent, InternClawAgent
1915

2016
__all__ = [
2117
'Agent',
@@ -31,6 +27,5 @@
3127
'AsyncCompactAgent',
3228
'estimate_token_count',
3329
'AsyncEnvAgent',
34-
'AsyncPolicyAgent',
3530
'InternClawAgent',
3631
]

lagent/agents/aggregator/default_aggregator.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
class DefaultAggregator:
99

10-
def aggregate(self,
11-
messages: Memory,
12-
name: str,
13-
parser: StrParser = None,
14-
system_instruction: str = None,
15-
tools: List[Dict] = None,
16-
) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]:
10+
def aggregate(
11+
self,
12+
messages: Memory,
13+
name: str,
14+
parser: StrParser = None,
15+
system_instruction: str = None,
16+
tools: List[Dict] = None,
17+
) -> Tuple[List[Dict[str, str]], Optional[List[Dict]]]:
1718
_message = []
1819
messages = messages.get_memory()
1920
if system_instruction:
@@ -38,12 +39,17 @@ def aggregate(self,
3839
)
3940
)
4041
else:
41-
if len(_message) > 0 and _message[-1]['role'] == 'user':
42+
if (
43+
len(_message) > 0
44+
and _message[-1]['role'] == 'user'
45+
and isinstance(_message[-1]['content'], str)
46+
and isinstance(user_message, str)
47+
):
4248
_message[-1]['content'] += user_message
4349
_message[-1]['extra_info'] = extra_info
4450
else:
4551
_message.append(dict(role='user', content=user_message, extra_info=extra_info))
46-
52+
4753
latest_env_info = None
4854
for message in messages:
4955
if getattr(message, 'env_info', None) is not None:
@@ -52,7 +58,7 @@ def aggregate(self,
5258
tools_to_use = tools
5359
if latest_env_info and latest_env_info.get("tools"):
5460
tools_to_use = latest_env_info.get("tools")
55-
61+
5662
return _message, tools_to_use
5763

5864
@staticmethod

lagent/agents/fc_agent.py

Lines changed: 132 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
import asyncio
22
import json
3+
import logging
4+
import platform
35
from copy import deepcopy
46
from dataclasses import asdict
5-
from typing import Dict, List, Literal, Optional, Union
7+
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
68

79
from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed
810

911
from lagent.actions import AsyncActionExecutor
1012
from lagent.hooks import Hook
1113
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode, AgentMessage, AgentStatusCode
14+
from lagent.skills.skills import SkillsLoader
1215
from lagent.utils import create_object, truncate_text
1316
from .agent import AsyncAgent
1417

18+
logger = logging.getLogger("lagent.agents.fc_agent")
19+
1520
DEFAULT_TOOL_TEMPLATE = """# Tools
1621
1722
You may call one or more functions to assist with the user query.
@@ -27,7 +32,9 @@
2732
</tool_call>"""
2833

2934

30-
def get_tool_prompt(actions: list, exclude_arguments: list = None, template: str = DEFAULT_TOOL_TEMPLATE) -> str:
35+
def get_tool_prompt(
36+
actions: list, exclude_arguments: list = None, to_string: bool = True, template: str = DEFAULT_TOOL_TEMPLATE
37+
) -> Union[str, List[dict]]:
3138
exclude_arguments = exclude_arguments or ['session_id']
3239

3340
def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') -> dict:
@@ -57,60 +64,138 @@ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') ->
5764
tools.append(_convert_tool_schema(api, f"{action.name}.{{}}"))
5865
else:
5966
tools.append(_convert_tool_schema(action_desc))
67+
if not to_string:
68+
return tools
6069
return template.format(tools='\n'.join([json.dumps(tool, ensure_ascii=False) for tool in tools]))
6170

6271

6372
class FunctionCallAgent(AsyncAgent):
6473
def __init__(
6574
self,
66-
select_agent: Union[Dict, AsyncAgent],
75+
policy_agent: Union[Dict, AsyncAgent],
6776
env_agent: Union[Dict, AsyncAgent],
68-
finish_condition: callable = lambda x, _: x and not x.tool_calls,
77+
compact_agent: Optional[Dict] = None,
78+
consolidate_agent: Optional[Dict] = None,
79+
finish_condition: callable = lambda m, _: m and not m.tool_calls,
6980
max_turn: Optional[int] = None,
81+
initialize_input: bool = True,
7082
name: Optional[str] = None,
7183
):
7284
super().__init__(name=name)
73-
self.select_agent = create_object(select_agent)
85+
self.policy_agent = create_object(policy_agent)
7486
self.env_agent = create_object(env_agent)
87+
self.compact_agent = create_object(compact_agent)
88+
self.consolidate_agent = create_object(consolidate_agent)
7589
self.finish_condition = finish_condition
7690
self.max_turn = max_turn
91+
self.initialize_input = initialize_input
7792

78-
async def forward(self, env_message: AgentMessage, session_id: str | int, **kwargs):
79-
selection_message: AgentMessage = None
93+
async def forward(self, env_message: AgentMessage, **kwargs):
94+
policy_message: AgentMessage = None
8095
current_turn = 0
81-
while (self.finish_condition is None or not self.finish_condition(selection_message, env_message)) and (
96+
if self.initialize_input:
97+
env_message = await self.env_agent(env_message, **kwargs)
98+
99+
while (self.finish_condition is None or not self.finish_condition(policy_message, env_message)) and (
82100
self.max_turn is None or current_turn < self.max_turn
83101
):
84-
selection_message = await self.select_agent(env_message, session_id=session_id, **kwargs)
85-
if selection_message.stream_state == AgentStatusCode.SERVER_ERR:
102+
policy_message = await self.policy_agent(env_message, **kwargs)
103+
if policy_message.stream_state == AgentStatusCode.SERVER_ERR:
86104
raise ValueError("Rollout response error: state is neither completed nor aborted!")
87-
if selection_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT:
105+
if policy_message.stream_state == AgentStatusCode.SESSION_OUT_OF_LIMIT:
88106
for _ in range(2): # remove the last two messages
89-
self.select_agent.memory.get(session_id).delete(-1)
107+
self.policy_agent.memory.delete(-1)
90108
return AgentMessage(
91109
sender=self.name,
92-
content='Exceeded context length limit',
93-
finish_reason=selection_message.finish_reason,
110+
content=policy_message.content,
111+
finish_reason=policy_message.finish_reason,
94112
)
95-
if selection_message.finish_reason == 'abort':
96-
return AgentMessage(sender=self.name, content='Aborted request', finish_reason='abort')
97-
env_message = await self.env_agent(selection_message, session_id=session_id)
113+
if policy_message.finish_reason == 'abort':
114+
return AgentMessage(sender=self.name, content=policy_message.content, finish_reason='abort')
115+
116+
# Orchestrator manages memory
117+
await self._maybe_manage_memory(policy_message, env_message)
118+
119+
env_message = await self.env_agent(policy_message)
98120
current_turn += 1
121+
if policy_message is not None:
122+
return AgentMessage(sender=self.name, content=policy_message.content, finish_reason='stop')
99123
return AgentMessage(sender=self.name, content="Finished", finish_reason='stop')
100124

125+
async def _maybe_manage_memory(self, policy_message: AgentMessage, env_message: AgentMessage) -> None:
126+
"""Orchestrate compact and consolidate.
127+
128+
Orchestrator calls policy's aggregator to get formatted_messages,
129+
checks should_compact, and if triggered:
130+
1. Runs consolidate_agent (optional)
131+
2. Runs compact_agent to produce summary
132+
3. Injects summary + boundary into env_message
133+
ContextBuilder reads these on the next turn.
134+
"""
135+
if not self.compact_agent:
136+
return
137+
138+
from lagent.agents.compact_agent import estimate_token_count
139+
140+
state = self.get_messages()
141+
formatted_messages, tools = state['policy_agent.messages'], state['policy_agent.tools']
142+
compact_input = AgentMessage(
143+
sender=self.name,
144+
content=formatted_messages,
145+
extra_info={'context_tokens': estimate_token_count(formatted_messages, tools)},
146+
)
147+
if not (hasattr(self.compact_agent, 'should_compact') and self.compact_agent.should_compact(compact_input)):
148+
return
149+
150+
# 1. Consolidate first (preserve info before compacting)
151+
if self.consolidate_agent:
152+
try:
153+
await self.consolidate_agent(compact_input)
154+
self.consolidate_agent.reset(recursive=True)
155+
logger.info("Consolidation completed")
156+
except Exception:
157+
logger.exception("Consolidation failed, continuing with compact")
158+
# 2. Compact — inject summary + boundary into env_message
159+
try:
160+
summary_msg = await self.compact_agent(compact_input)
161+
self.compact_agent.reset(recursive=True)
162+
if summary_msg and summary_msg.content:
163+
if env_message.env_info is None:
164+
env_message.env_info = {}
165+
env_message.env_info['conversation_summary'] = summary_msg.content
166+
env_message.env_info['compact_boundary'] = len(self.policy_agent.memory.memory)
167+
logger.info("Compact summary injected (%d chars)", len(summary_msg.content))
168+
except Exception:
169+
logger.exception("Compact failed")
170+
171+
172+
class MemoryProvider(Protocol):
173+
async def get_info(self) -> dict:
174+
"""Return long-term memory info for EnvAgent's env_info. The content and format are flexible, but should be concise."""
175+
...
176+
101177

102178
class EnvAgent(AsyncAgent):
103179
def __init__(
104180
self,
105-
actions: list,
181+
actions,
182+
skills: Optional[SkillsLoader] = None,
183+
long_term_memory: Optional[MemoryProvider] = None,
106184
stateful_tools: List[str] = None,
107185
max_tool_response_length: int = None,
108186
tool_response_truncate_side: Literal['left', 'right', 'middle'] = 'middle',
109187
action_hooks: List[Union[dict, Hook]] = None,
110188
name: Optional[str] = None,
111189
):
112190
super().__init__(name=name)
113-
self.actions = AsyncActionExecutor(actions, hooks=action_hooks)
191+
if isinstance(actions, AsyncActionExecutor):
192+
for action_hook in action_hooks or []:
193+
actions.register_hook(create_object(action_hook))
194+
self.actions = actions
195+
else:
196+
self.actions = AsyncActionExecutor(actions, hooks=action_hooks)
197+
self.skills = create_object(skills)
198+
self.long_term_memory = create_object(long_term_memory)
114199
self.stateful_tools = stateful_tools or []
115200
self.max_tool_response_length = max_tool_response_length
116201
self.tool_response_truncate_side = tool_response_truncate_side
@@ -124,41 +209,56 @@ def __init__(
124209
retry_error_callback=lambda retry_state: retry_state.outcome.result(),
125210
)
126211

127-
async def forward(self, selection_message: AgentMessage, session_id: str | int, **kwargs):
128-
if not selection_message.tool_calls:
129-
return AgentMessage(sender=self.name, content='No tool call')
212+
async def get_env_info(self) -> Dict[str, Any]:
213+
env_info = {'skills': '', 'active_skills': '', 'memory': '', 'tools': [], 'runtime': {}}
214+
if self.skills is not None:
215+
env_info['skills'] = await self.skills.build_skills_summary()
216+
always_skills = await self.skills.get_always_skills()
217+
if always_skills:
218+
env_info['active_skills'] = await self.skills.load_skills_for_context(always_skills)
219+
if self.long_term_memory is not None:
220+
env_info['memory'] = await self.long_term_memory.get_info()
221+
if self.actions:
222+
env_info['tools'] = get_tool_prompt(list(self.actions.actions.values()), to_string=False)
223+
for name in ['system', 'machine', 'python_version']:
224+
env_info['runtime'][name] = getattr(platform, name)()
225+
return env_info
226+
227+
async def forward(self, message: AgentMessage, **kwargs):
228+
if not message.tool_calls:
229+
return AgentMessage(sender=self.name, content=message.content, env_info=await self.get_env_info())
130230

131231
tool_responses = await asyncio.gather(
132-
*[
133-
self._retry_mechanism(self.execute_tool)(tool_call, session_id)
134-
for tool_call in selection_message.tool_calls
135-
]
232+
*[self._retry_mechanism(self.execute_tool)(tool_call) for tool_call in message.tool_calls]
136233
)
137-
for tool_call_id, tool_response in zip(selection_message.tool_calls_ids, tool_responses):
234+
for tool_call_id, tool_response in zip(message.tool_calls_ids, tool_responses):
138235
tool_response.tool_call_id = tool_call_id
139236
res = tool_response.format_result()
140237
if self.max_tool_response_length is not None and len(res) > self.max_tool_response_length:
141238
res = truncate_text(res, max_num=self.max_tool_response_length, side=self.tool_response_truncate_side)
142239
tool_response.result = [{'type': 'text', 'content': res}]
143-
return AgentMessage(sender=self.name, content=[asdict(resp) for resp in tool_responses])
240+
return AgentMessage(
241+
sender=self.name, content=[asdict(resp) for resp in tool_responses], env_info=await self.get_env_info()
242+
)
144243

145-
async def execute_tool(self, tool_call: dict, session_id: str | int) -> ActionReturn:
244+
async def execute_tool(self, tool_call: dict) -> ActionReturn:
245+
tool_call = deepcopy(tool_call)
146246
try:
247+
if 'function' in tool_call:
248+
tool_call = tool_call['function']
147249
if tool_call['name'].split('.', 1)[0] not in self.actions:
148250
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Tool {tool_call["name"]} Not Found')
149251
if isinstance(tool_call['arguments'], str):
150252
tool_call['arguments'] = json.loads(tool_call['arguments'])
151253
if tool_call['name'] in self.stateful_tools:
152-
tool_call = deepcopy(tool_call)
153-
tool_call['arguments']['session_id'] = session_id
254+
tool_call['arguments']['session_id'] = str(id(self))
154255
except Exception as e:
155256
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=f'Invalid tool call format: {str(e)}')
156257
tool_response: ActionReturn = (
157258
await self.actions(
158259
AgentMessage(
159260
sender='assistant', content=dict(name=tool_call['name'], parameters=tool_call['arguments'])
160261
),
161-
session_id=session_id,
162262
)
163263
).content
164264
return tool_response

lagent/agents/internclaw_agent.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -56,23 +56,6 @@ def _convert_tool_schema(action_description: dict, name_pattern: str = '{}') ->
5656
return tools
5757

5858

59-
class AsyncPolicyAgent(AsyncAgent):
60-
61-
async def forward(self, *message, **kwargs):
62-
formatted_messages, tools = self.aggregator.aggregate(
63-
self.memory, self.name, self.output_format, self.template
64-
)
65-
llm_response = await self.llm.chat(formatted_messages, tools=tools, **kwargs)
66-
# message = AgentMessage(
67-
# sender=self.name,
68-
# content=llm_response.get('content') or '',
69-
# tool_calls=llm_response.get('tool_calls') or [],
70-
# reasoning_content=llm_response.get('reasoning_content'),
71-
# )
72-
# return message
73-
return llm_response
74-
75-
7659
class AsyncEnvAgent(AsyncAgent):
7760
def __init__(self, actions, skills: SkillsLoader = None, long_term_memory=None, **kwargs):
7861
super().__init__(**kwargs)
@@ -168,13 +151,11 @@ async def _inner_func(tool_call):
168151
result_dict['tool_call_id'] = tc.get('id', '')
169152
if resp.valid != ActionValidCode.OPEN:
170153
result_dict['errmsg'] = (
171-
f'Tool Call Error: {resp.errmsg} in tool call '
172-
f'{json.dumps(tc, ensure_ascii=False)}'
154+
f'Tool Call Error: {resp.errmsg} in tool call ' f'{json.dumps(tc, ensure_ascii=False)}'
173155
)
174156
elif resp.state != ActionStatusCode.SUCCESS:
175157
result_dict['errmsg'] = (
176-
f'Tool Call Error: {resp.errmsg} in tool call '
177-
f'{json.dumps(tc, ensure_ascii=False)}'
158+
f'Tool Call Error: {resp.errmsg} in tool call ' f'{json.dumps(tc, ensure_ascii=False)}'
178159
)
179160
if resp.state == ActionStatusCode.ARGS_ERROR:
180161
reward = -1
@@ -352,7 +333,7 @@ async def main():
352333

353334
# ── 4. Policy agent ──
354335
aggregator = InternClawContextBuilder(workspace, tools=None)
355-
policy = AsyncPolicyAgent(
336+
policy = AsyncAgent(
356337
llm=model,
357338
aggregator=aggregator,
358339
hooks=[logger_hook],
@@ -374,7 +355,7 @@ async def main():
374355
)
375356

376357
# ── 7. Consolidate agent (standard InternClawAgent) ──
377-
consolidate_policy = AsyncPolicyAgent(
358+
consolidate_policy = AsyncAgent(
378359
name='consolidate_policy',
379360
llm=model,
380361
template=CONSOLIDATION_PROMPT,

0 commit comments

Comments
 (0)