Skip to content

Commit 0553aaf

Browse files
authored
get message return tools (#315)
1 parent a77ec89 commit 0553aaf

3 files changed

Lines changed: 47 additions & 33 deletions

File tree

lagent/agents/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,9 @@ def get_messages(self, prefix='', destination=None) -> Dict[str, List[dict]]:
183183
if self.aggregator:
184184
messages = self.aggregator.aggregate(self.memory, self.name, self.output_format, self.template)
185185
if isinstance(messages, tuple):
186-
messages, _ = messages
186+
messages, tools = messages
187187
destination[prefix + 'messages'] = messages
188+
destination[prefix + 'tools'] = tools
188189
for name, agent in getattr(self, '_agents', {}).items():
189190
if isinstance(agent, Agent):
190191
agent.get_messages(destination=destination, prefix=prefix + name + '.')

lagent/agents/internclaw_agent.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,24 @@ async def forward(self, message, **kwargs):
128128
)
129129
async def _inner_func(tool_call):
130130
tool_call = deepcopy(tool_call)
131+
tool_name = tool_call['function'].get('name')
131132
try:
132-
if tool_call['function']['name'].split('.', 1)[0] not in self.actions:
133+
if tool_name.split('.', 1)[0] not in self.actions:
133134
return ActionReturn(
134-
valid=ActionValidCode.INVALID, errmsg=f"Tool {tool_call['function']['name']} Not Found"
135+
type=tool_name,
136+
args=tool_call['function'].get('arguments'),
137+
valid=ActionValidCode.INVALID,
138+
errmsg=f"Tool {tool_name} Not Found",
135139
)
136140
if isinstance(tool_call['function']['arguments'], str):
137141
tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
138142
except Exception as e:
139-
return ActionReturn(valid=ActionValidCode.INVALID, errmsg=str(e))
143+
return ActionReturn(
144+
type=tool_name,
145+
args=tool_call['function'].get('arguments'),
146+
valid=ActionValidCode.INVALID,
147+
errmsg=str(e),
148+
)
140149
tool_response: ActionReturn = (
141150
await self.actions(
142151
AgentMessage(
@@ -151,30 +160,29 @@ async def _inner_func(tool_call):
151160

152161
tasks = [_inner_func(tool_call) for tool_call in message.tool_calls]
153162
responses = await asyncio.gather(*tasks)
154-
for i, resp in enumerate(responses):
155-
if resp.valid != ActionValidCode.OPEN:
156-
return AgentMessage(
157-
sender=self.name,
158-
content=f'Tool Call Error: {resp.errmsg} in tool call '
159-
f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}',
160-
)
161-
if resp.state != ActionStatusCode.SUCCESS:
162-
return AgentMessage(
163-
sender=self.name,
164-
content=f'Tool Call Error: {resp.errmsg} in tool call '
165-
f'{json.dumps(message.tool_calls[i], ensure_ascii=False)}',
166-
reward=-1 if resp.state == ActionStatusCode.ARGS_ERROR else 0,
167-
)
168163
# Pair each ActionReturn with its tool_call_id for proper LLM API formatting
169164
tool_results = []
170-
for tc, r in zip(message.tool_calls, responses):
171-
result_dict = asdict(r)
165+
reward = 0.0
166+
for tc, resp in zip(message.tool_calls, responses):
167+
result_dict = asdict(resp)
172168
result_dict['tool_call_id'] = tc.get('id', '')
169+
if resp.valid != ActionValidCode.OPEN:
170+
result_dict['errmsg'] = (
171+
f'Tool Call Error: {resp.errmsg} in tool call '
172+
f'{json.dumps(tc, ensure_ascii=False)}'
173+
)
174+
elif resp.state != ActionStatusCode.SUCCESS:
175+
result_dict['errmsg'] = (
176+
f'Tool Call Error: {resp.errmsg} in tool call '
177+
f'{json.dumps(tc, ensure_ascii=False)}'
178+
)
179+
if resp.state == ActionStatusCode.ARGS_ERROR:
180+
reward = -1
173181
tool_results.append(result_dict)
174182
return_message = AgentMessage(
175183
sender=self.name,
176184
content=tool_results,
177-
reward=0.0,
185+
reward=reward,
178186
env_info=await self.get_env_info(),
179187
)
180188

lagent/utils/util.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
from functools import partial
1111
from logging.handlers import RotatingFileHandler
12-
from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast
12+
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
1313

1414

1515
def load_class_from_string(class_path: str, path=None):
@@ -29,29 +29,34 @@ def load_class_from_string(class_path: str, path=None):
2929
sys.path.remove(path)
3030

3131

32+
def _is_ray_actor_class(obj_type) -> bool:
33+
try:
34+
from ray.actor import ActorClass
35+
except ImportError:
36+
return False
37+
return isinstance(obj_type, ActorClass)
38+
39+
3240
def create_object(config: Union[Dict, Any] = None):
3341
"""Create an instance based on the configuration where 'type' is a
3442
preserved key to indicate the class (path). When accepting non-dictionary
3543
input, the function degenerates to an identity.
3644
"""
37-
from ray.actor import ActorClass
38-
3945
if config is None or not isinstance(config, dict):
4046
return config
41-
assert isinstance(config, dict) and 'type' in config
47+
assert 'type' in config
4248

4349
config = config.copy()
4450
obj_type = config.pop('type')
4551
if isinstance(obj_type, str):
4652
obj_type = load_class_from_string(obj_type)
47-
if isinstance(obj_type, ActorClass):
48-
obj = cast(ActorClass, obj_type).remote(**config)
49-
elif inspect.isclass(obj_type):
50-
obj = obj_type(**config)
51-
else:
52-
assert callable(obj_type)
53-
obj = partial(obj_type, **config)
54-
return obj
53+
54+
if _is_ray_actor_class(obj_type):
55+
return obj_type.remote(**config)
56+
if inspect.isclass(obj_type):
57+
return obj_type(**config)
58+
assert callable(obj_type)
59+
return partial(obj_type, **config)
5560

5661

5762
async def async_as_completed(futures: Iterable[asyncio.Future]):

0 commit comments

Comments
 (0)