Skip to content

Commit 0ab2e2f

Browse files
Refactor function calling (#309)
* refactor * backward compatibility * improve memory serialization and deserialization * improve `AsyncMCPClient` * fix memory removal index * modify agent message interface * ensure forward idempotence when resuming aborted rollout * ensure forward idempotence when resuming aborted rollout * Add UID for messages * fix
1 parent b49a275 commit 0ab2e2f

11 files changed

Lines changed: 1134 additions & 45 deletions

File tree

lagent/actions/mcp_client.py

Lines changed: 410 additions & 0 deletions
Large diffs are not rendered by default.

lagent/actions/web_visitor.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
import asyncio
2+
import json
3+
import re
4+
import traceback
5+
import warnings
6+
from typing import Any, List
7+
8+
from transformers import AutoTokenizer
9+
10+
from lagent.actions import AsyncActionMixin, BaseAction
11+
from lagent.schema import ActionStatusCode, ActionValidCode, AgentMessage
12+
from lagent.utils import create_object
13+
14+
15+
def extract_last_json(text: str) -> dict | None:
16+
"""
17+
Extracts the last valid JSON object from a string.
18+
Handles Markdown code blocks (```json ... ```) and raw JSON strings.
19+
"""
20+
try:
21+
# 1. Try to find JSON within Markdown code blocks first
22+
# Look for ```json ... ``` or just ``` ... ```
23+
code_block_pattern = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)
24+
matches = code_block_pattern.findall(text)
25+
if matches:
26+
return json.loads(matches[-1])
27+
28+
# 2. If no code blocks, try to find the last outermost pair of braces
29+
# This regex looks for { ... } lazily but we want the last one.
30+
# A simple approach for nested JSON is tricky with regex,
31+
# so we scan from right to left for the last '}' and find its matching '{'.
32+
33+
stack, end_idx = 0, -1
34+
# Reverse search to find the last valid JSON structure
35+
for i in range(len(text) - 1, -1, -1):
36+
char = text[i]
37+
if char == '}':
38+
if stack == 0:
39+
end_idx = i
40+
stack += 1
41+
elif char == '{':
42+
if stack > 0:
43+
stack -= 1
44+
if stack == 0 and end_idx != -1:
45+
# Found a potential outermost JSON object
46+
candidate = text[i : end_idx + 1]
47+
try:
48+
return json.loads(candidate)
49+
except json.JSONDecodeError:
50+
# If this chunk isn't valid, reset and keep searching backwards
51+
# (or you might decide to stop here depending on strictness)
52+
stack, end_idx = 0, -1
53+
return None
54+
except Exception:
55+
return None
56+
57+
58+
class WebVisitor(AsyncActionMixin, BaseAction):
59+
60+
EXTRACTION_PROMPT = """Please process the following webpage content and user goal to extract relevant information:
61+
62+
## **Webpage Content**
63+
{webpage_content}
64+
65+
## **User Goal**
66+
{goal}
67+
68+
## **Task Guidelines**
69+
1. **Content Scanning for Rationale**: Locate the **specific sections/data** directly related to the user's goal within the webpage content
70+
2. **Key Extraction for Evidence**: Identify and extract the **most relevant information** from the content, you never miss any important information, output the **full original context** of the content as far as possible, it can be more than three paragraphs.
71+
3. **Summary Output for Summary**: Organize into a concise paragraph with logical flow, prioritizing clarity and judge the contribution of the information to the goal.
72+
73+
**Final Output Format using JSON format has "rational", "evidence", "summary" fields**
74+
"""
75+
76+
def __init__(
77+
self,
78+
browse_tool: BaseAction | dict,
79+
llm: Any,
80+
max_browse_attempts: int = 3,
81+
max_extract_attempts: int = 3,
82+
sleep_interval: int = 3,
83+
truncate_browse_response_length: int | None = None,
84+
tokenizer_path: str | None = None,
85+
name: str = 'visit',
86+
):
87+
super().__init__(
88+
description={
89+
'name': name,
90+
'description': 'Visit webpage(s) and return the summary of the content.',
91+
'parameters': [
92+
{
93+
'name': 'url',
94+
'type': ['STRING', 'ARRAY'],
95+
"items": {"type": "string"},
96+
"minItems": 1,
97+
'description': 'The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs.',
98+
},
99+
{'name': 'goal', 'type': 'STRING', 'description': 'The goal of the visit for webpage(s).'},
100+
],
101+
'required': ['url', 'goal'],
102+
}
103+
)
104+
browse_tool = create_object(browse_tool)
105+
assert not browse_tool.is_toolkit and browse_tool.description['required'] == [
106+
'url'
107+
], "browse_tool must be a single-tool action with only 'url' as required argument."
108+
self.browse_tool = browse_tool
109+
self.llm = create_object(llm)
110+
self.max_browse_attempts = max_browse_attempts
111+
self.max_extract_attempts = max_extract_attempts
112+
self.sleep_interval = sleep_interval
113+
self.truncate_browse_response_length = truncate_browse_response_length
114+
self.tokenizer = (
115+
AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) if tokenizer_path else None
116+
)
117+
if self.truncate_browse_response_length is not None and self.tokenizer is None:
118+
warnings.warn(
119+
'truncate_browse_response_length is set but tokenizer_path is not provided. '
120+
'The raw webpage content will be truncated by characters instead of tokens.'
121+
)
122+
123+
async def run(self, url: str | List[str], goal: str) -> str:
124+
if isinstance(url, str):
125+
url = [url]
126+
127+
async def _inner_call(single_url: str) -> str:
128+
try:
129+
return await self._read_webpage(single_url, goal)
130+
except Exception as e:
131+
return f"Error fetching {single_url}: {str(e)}"
132+
133+
response = await asyncio.gather(*[_inner_call(single_url) for single_url in url])
134+
return "\n=======\n".join(response).strip()
135+
136+
async def _read_webpage(self, url: str, goal: str) -> str:
137+
tool_response = compressed = None
138+
return_template = (
139+
f"The useful information in {url} for user goal {goal} as follows: \n\n"
140+
f"Evidence in page: \n{{evidence}}\n\nSummary: \n{{summary}}\n\n"
141+
)
142+
for _ in range(self.max_browse_attempts):
143+
resp = await self.browse_tool({'url': url})
144+
if resp.valid == ActionValidCode.OPEN and resp.state == ActionStatusCode.SUCCESS:
145+
tool_response = resp.format_result()
146+
break
147+
await asyncio.sleep(self.sleep_interval)
148+
else:
149+
return return_template.format(
150+
evidence="The provided webpage content could not be accessed. Please check the URL or file format.",
151+
summary="The webpage content could not be processed, and therefore, no information is available.",
152+
)
153+
154+
if self.truncate_browse_response_length is not None:
155+
tool_response = (
156+
self.tokenizer.decode(
157+
self.tokenizer.encode(
158+
tool_response,
159+
max_length=self.truncate_browse_response_length,
160+
truncation=True,
161+
add_special_tokens=False,
162+
)
163+
)
164+
if self.tokenizer is not None
165+
else tool_response[: self.truncate_browse_response_length]
166+
)
167+
168+
for _ in range(self.max_extract_attempts):
169+
try:
170+
prompt = self.EXTRACTION_PROMPT.format(webpage_content=tool_response, goal=goal)
171+
llm_response = await self.llm.chat([{'role': 'user', 'content': prompt}])
172+
if llm_response and not isinstance(llm_response, str):
173+
llm_response = (
174+
llm_response.content
175+
if isinstance(llm_response, AgentMessage)
176+
else llm_response.choices[0].message.content
177+
)
178+
if not llm_response or len(llm_response) < 10:
179+
tool_response = tool_response[: int(len(tool_response) * 0.7)]
180+
continue
181+
compressed = extract_last_json(llm_response)
182+
if isinstance(compressed, dict) and all(
183+
key in compressed for key in ['rational', 'evidence', 'summary']
184+
):
185+
break
186+
except Exception:
187+
print(f"Error in extracting information: {traceback.format_exc()}")
188+
await asyncio.sleep(self.sleep_interval)
189+
else:
190+
return return_template.format(
191+
evidence="Failed to extract relevant information from the webpage content.",
192+
summary="The webpage content could not be processed, and therefore, no information is available.",
193+
)
194+
return return_template.format(evidence=compressed['evidence'], summary=compressed['summary'])

lagent/agents/agent.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
for hook in hooks:
5959
hook = create_object(hook)
6060
self.register_hook(hook)
61+
self._sessions_to_scroll = set()
6162

6263
def update_memory(self, message, session_id=0):
6364
if self.memory:
@@ -70,10 +71,28 @@ def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> AgentMessa
7071
result = hook.before_agent(self, message, session_id)
7172
if result:
7273
message = result
73-
self.update_memory(message, session_id=session_id)
74+
75+
# resume aborted rollout
76+
_message = self._scroll_buffer(message[-1], session_id)
77+
if _message is not None:
78+
if _message.finish_reason != 'abort':
79+
_message = copy.deepcopy(_message)
80+
for hook in self._hooks.values():
81+
result = hook.after_agent(self, _message, session_id)
82+
if result:
83+
_message = result
84+
return _message
85+
message[-1].extra_info['partial_response'] = _message
86+
else:
87+
self.update_memory(message, session_id=session_id)
7488
response_message = self.forward(*message, session_id=session_id, **kwargs)
89+
if _message and _message.finish_reason == 'abort':
90+
message[-1].extra_info.pop('partial_response', None)
7591
if not isinstance(response_message, AgentMessage):
76-
response_message = AgentMessage(sender=self.name, content=response_message)
92+
if isinstance(response_message, str):
93+
response_message = AgentMessage(sender=self.name, content=response_message)
94+
else:
95+
response_message = AgentMessage.from_model_response(response_message, self.name)
7796
self.update_memory(response_message, session_id=session_id)
7897
response_message = copy.deepcopy(response_message)
7998
for hook in self._hooks.values():
@@ -158,6 +177,63 @@ def reset(self, session_id=0, keypath: Optional[str] = None, recursive: bool = F
158177
for agent in getattr(self, '_agents', {}).values():
159178
agent.reset(session_id, recursive=True)
160179

180+
def get_messages(self, session_id=0, keypath: Optional[str] = None) -> List[dict]:
181+
"""Get OpenAI format messages from memory.
182+
183+
Args:
184+
session_id (int): The session id of the memory.
185+
keypath (Optional[str]): The keypath of the sub-agent to get messages from. Default is None.
186+
187+
Returns:
188+
List[dict]: The messages from the memory including the sub-agent's system prompt.
189+
"""
190+
if keypath:
191+
keys, agent = keypath.split('.'), self
192+
for key in keys:
193+
agents = getattr(agent, '_agents', {})
194+
if key not in agents:
195+
raise KeyError(f'No sub-agent named {key} in {agent}')
196+
agent = agents[key]
197+
return agent.get_messages(session_id=session_id)
198+
if self.aggregator:
199+
return self.aggregator.aggregate(self.memory.get(session_id), self.name, self.output_format, self.template)
200+
raise ValueError(f'{self.name} has no aggregator to get messages')
201+
202+
def _scroll_buffer(self, message, session_id, hash_func=lambda m: m.uid):
203+
memory = self.memory and self.memory.get(session_id)
204+
if not memory:
205+
return
206+
mem = self.memory.get_memory(session_id)
207+
finish_reasons = [m.finish_reason for m in mem]
208+
if not ('abort' in finish_reasons or session_id in self._sessions_to_scroll):
209+
return
210+
if session_id not in self._sessions_to_scroll:
211+
self._enable_scroll_mode(session_id, recursive=True)
212+
aborted_msg_idx = finish_reasons.index('abort') if 'abort' in finish_reasons else len(mem) - 1
213+
memory.delete(range(aborted_msg_idx + 1, len(mem)))
214+
enc = hash_func(message)
215+
for i in range(0, aborted_msg_idx):
216+
if hash_func(mem[i]) == enc:
217+
ret = mem[i + 1]
218+
if i + 1 == aborted_msg_idx:
219+
if ret.finish_reason == 'abort':
220+
memory.delete(aborted_msg_idx)
221+
self._disable_scroll_mode(session_id)
222+
return ret
223+
self._disable_scroll_mode(session_id, recursive=True)
224+
225+
def _enable_scroll_mode(self, session_id, recursive=False):
226+
self._sessions_to_scroll.add(session_id)
227+
if recursive:
228+
for sub_agent in getattr(self, '_agents', {}).values():
229+
sub_agent._enable_scroll_mode(session_id, True)
230+
231+
def _disable_scroll_mode(self, session_id, recursive=False):
232+
self._sessions_to_scroll.discard(session_id)
233+
if recursive:
234+
for sub_agent in getattr(self, '_agents', {}).values():
235+
sub_agent._disable_scroll_mode(session_id, True)
236+
161237
def __repr__(self):
162238

163239
def _rcsv_repr(agent, n_indent=1):
@@ -183,10 +259,28 @@ async def __call__(self, *message: AgentMessage, session_id=0, **kwargs) -> Agen
183259
result = hook.before_agent(self, message, session_id)
184260
if result:
185261
message = result
186-
self.update_memory(message, session_id=session_id)
262+
263+
# resume aborted rollout
264+
_message = self._scroll_buffer(message[-1], session_id)
265+
if _message is not None:
266+
if _message.finish_reason != 'abort':
267+
_message = copy.deepcopy(_message)
268+
for hook in self._hooks.values():
269+
result = hook.after_agent(self, _message, session_id)
270+
if result:
271+
_message = result
272+
return _message
273+
message[-1].extra_info['partial_response'] = _message
274+
else:
275+
self.update_memory(message, session_id=session_id)
187276
response_message = await self.forward(*message, session_id=session_id, **kwargs)
277+
if _message and _message.finish_reason == 'abort':
278+
message[-1].extra_info.pop('partial_response', None)
188279
if not isinstance(response_message, AgentMessage):
189-
response_message = AgentMessage(sender=self.name, content=response_message)
280+
if isinstance(response_message, str):
281+
response_message = AgentMessage(sender=self.name, content=response_message)
282+
else:
283+
response_message = AgentMessage.from_model_response(response_message, self.name)
190284
self.update_memory(response_message, session_id=session_id)
191285
response_message = copy.deepcopy(response_message)
192286
for hook in self._hooks.values():
Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,42 @@
1-
from typing import Dict, List
1+
from typing import List
22

33
from lagent.memory import Memory
44
from lagent.prompts import StrParser
5+
from lagent.schema import ActionReturn
56

67

78
class DefaultAggregator:
89

9-
def aggregate(self,
10-
messages: Memory,
11-
name: str,
12-
parser: StrParser = None,
13-
system_instruction: str = None) -> List[Dict[str, str]]:
10+
def aggregate(self, messages: Memory, name: str, parser: StrParser = None, system_instruction=None) -> List[dict]:
1411
_message = []
1512
messages = messages.get_memory()
1613
if system_instruction:
17-
_message.extend(
18-
self.aggregate_system_intruction(system_instruction))
14+
_message.extend(self.aggregate_system_intruction(system_instruction))
1915
for message in messages:
2016
if message.sender == name:
21-
_message.append(
22-
dict(role='assistant', content=str(message.content)))
17+
_message.append(message.to_model_request())
2318
else:
24-
user_message = message.content
25-
if len(_message) > 0 and _message[-1]['role'] == 'user':
26-
_message[-1]['content'] += user_message
19+
user_message, extra_info = message.content, message.extra_info
20+
if isinstance(user_message, list):
21+
for m in user_message:
22+
if isinstance(m, dict):
23+
m = ActionReturn(**m)
24+
assert isinstance(m, ActionReturn), f"Expected m to be ActionReturn, but got {type(m)}"
25+
_message.append(
26+
dict(
27+
role='tool',
28+
tool_call_id=m.tool_call_id,
29+
content=m.format_result(),
30+
name=m.type,
31+
extra_info=extra_info,
32+
)
33+
)
2734
else:
28-
_message.append(dict(role='user', content=user_message))
35+
if len(_message) > 0 and _message[-1]['role'] == 'user':
36+
_message[-1]['content'] += user_message
37+
_message[-1]['extra_info'] = extra_info
38+
else:
39+
_message.append(dict(role='user', content=user_message, extra_info=extra_info))
2940
return _message
3041

3142
@staticmethod
@@ -39,6 +50,5 @@ def aggregate_system_intruction(system_intruction) -> List[dict]:
3950
if not isinstance(msg, dict):
4051
raise TypeError(f'Unsupported message type: {type(msg)}')
4152
if not ('role' in msg and 'content' in msg):
42-
raise KeyError(
43-
f"Missing required key 'role' or 'content': {msg}")
53+
raise KeyError(f"Missing required key 'role' or 'content': {msg}")
4454
return system_intruction

0 commit comments

Comments
 (0)