Skip to content

Commit a77ec89

Browse files
Refactor get_message interface (#314)
refactor `get_message` interface
1 parent ebf4ef6 commit a77ec89

2 files changed

Lines changed: 39 additions & 26 deletions

File tree

lagent/agents/agent.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -168,26 +168,27 @@ def reset(self, keypath: Optional[str] = None, recursive: bool = False):
168168
for agent in getattr(self, '_agents', {}).values():
169169
agent.reset(recursive=True)
170170

171-
def get_messages(self, keypath: Optional[str] = None) -> List[dict]:
172-
"""Get OpenAI format messages from memory.
171+
def get_messages(self, prefix='', destination=None) -> Dict[str, List[dict]]:
172+
"""Get OpenAI format messages from all agents recursively, similar to state_dict.
173173
174174
Args:
175-
keypath (Optional[str]): The keypath of the sub-agent to get messages from. Default is None.
175+
prefix (str): The prefix to prepend to the key. Default is ''.
176+
destination (Optional[Dict]): The destination dict to store the messages. Default is None.
176177
177178
Returns:
178-
List[dict]: The messages from the memory including the sub-agent's system prompt.
179+
Dict[str, List[dict]]: A dict mapping agent keypaths to their OpenAI format message lists.
179180
"""
180-
if keypath:
181-
keys, agent = keypath.split('.'), self
182-
for key in keys:
183-
agents = getattr(agent, '_agents', {})
184-
if key not in agents:
185-
raise KeyError(f'No sub-agent named {key} in {agent}')
186-
agent = agents[key]
187-
return agent.get_messages()
181+
if destination is None:
182+
destination = {}
188183
if self.aggregator:
189-
return self.aggregator.aggregate(self.memory, self.name, self.output_format, self.template)
190-
raise ValueError(f'{self.name} has no aggregator to get messages')
184+
messages = self.aggregator.aggregate(self.memory, self.name, self.output_format, self.template)
185+
if isinstance(messages, tuple):
186+
messages, _ = messages
187+
destination[prefix + 'messages'] = messages
188+
for name, agent in getattr(self, '_agents', {}).items():
189+
if isinstance(agent, Agent):
190+
agent.get_messages(destination=destination, prefix=prefix + name + '.')
191+
return destination
191192

192193
def _scroll_buffer(self, message, hash_func=lambda m: m.uid):
193194
if not self.memory:

lagent/serving/sandbox/daemon.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"""
2929

3030
from __future__ import annotations
31-
3231
import argparse
3332
import asyncio
3433
import inspect
@@ -96,9 +95,7 @@ async def start(self) -> None:
9695
"""Start listening. Removes stale socket file if present."""
9796
if os.path.exists(self.sock_path):
9897
os.unlink(self.sock_path)
99-
self._server = await asyncio.start_unix_server(
100-
self._handle_client, path=self.sock_path
101-
)
98+
self._server = await asyncio.start_unix_server(self._handle_client, path=self.sock_path)
10299
os.chmod(self.sock_path, 0o777)
103100
logger.info("%s listening on %s", self.__class__.__name__, self.sock_path)
104101
await self._server.serve_forever()
@@ -139,10 +136,12 @@ async def _dispatch(self, request: dict) -> dict:
139136
if cmd == "ping":
140137
return {"status": "ok", "type": self.daemon_type}
141138
if cmd == "shutdown":
139+
142140
async def _delayed_close():
143141
await asyncio.sleep(0.1)
144142
if self._server:
145143
self._server.close()
144+
146145
asyncio.create_task(_delayed_close())
147146
return {"status": "shutting_down"}
148147
return {"error": f"Unknown command: {cmd}"}
@@ -197,18 +196,22 @@ async def _dispatch(self, request: dict) -> dict:
197196
name = request.get("name")
198197
parameters = request.get("parameters", {})
199198
if not name:
200-
return dataclass2dict(ActionReturn(
201-
errmsg="Missing 'name' in request",
202-
state=ActionStatusCode.ARGS_ERROR,
203-
))
199+
return dataclass2dict(
200+
ActionReturn(
201+
errmsg="Missing 'name' in request",
202+
state=ActionStatusCode.ARGS_ERROR,
203+
)
204+
)
204205

205206
try:
206207
action_return = await self.executor.forward(name, parameters)
207208
except Exception as e:
208209
logger.exception("Action %s failed", name)
209210
action_return = ActionReturn(
210-
args=parameters, type=name,
211-
errmsg=str(e), state=ActionStatusCode.API_ERROR,
211+
args=parameters,
212+
type=name,
213+
errmsg=str(e),
214+
state=ActionStatusCode.API_ERROR,
212215
)
213216
return dataclass2dict(action_return)
214217

@@ -349,6 +352,12 @@ async def _dispatch(self, request: dict) -> dict:
349352
except Exception as e:
350353
return {"error": str(e)}
351354

355+
if cmd == 'get_messages':
356+
try:
357+
return self.agent.get_messages()
358+
except Exception as e:
359+
return {"error": str(e)}
360+
352361
return {"error": f"Unknown command: {cmd}"}
353362

354363
@staticmethod
@@ -404,11 +413,14 @@ def main():
404413
# -- start --
405414
p_start = sub.add_parser("start", help="Start the daemon")
406415
p_start.add_argument(
407-
"--sock", default="/tmp/lagent_action.sock",
416+
"--sock",
417+
default="/tmp/lagent_action.sock",
408418
help="Unix socket path",
409419
)
410420
p_start.add_argument(
411-
"--mode", choices=["actions", "agent"], default="actions",
421+
"--mode",
422+
choices=["actions", "agent"],
423+
default="actions",
412424
help="'actions' for Level 1 (ActionDaemon), 'agent' for Level 2 (AgentDaemon)",
413425
)
414426
p_start.add_argument(

0 commit comments

Comments
 (0)