Skip to content

Commit 836e380

Browse files
committed
refactor: clean up unused imports and add text extraction method in MultiAgentContextTracker; add skip GPU check option in launcher
1 parent b548324 commit 836e380

3 files changed

Lines changed: 59 additions & 74 deletions

File tree

ajet/context_tracker/multiagent_tracking.py

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,7 @@
2020
from ajet.utils.color_hsl import adjust_color_hsl
2121
from ajet.utils.compute_madness import compute_string_madness
2222
from ajet.utils.tokenizer import ajet_apply_chat_template
23-
#
24-
from ajet.utils.msg_converter import (
25-
convert_grouped_steps_to_openai_format,
26-
convert_ext_msg_to_openai_format,
27-
agentscope_to_openai,
28-
openai_to_agentscope,
29-
agentscope_to_openai_grouped,
30-
openai_to_agentscope_grouped,
31-
)
23+
3224
@dataclass
3325
class TimelineMergingPolicyConfig:
3426
timeline_compare_level: str = "text"
@@ -82,6 +74,40 @@ def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool
8274
tools[i]["function"]["parameters"] = tools[i]["function"].pop("parameters")
8375
return tools
8476

77+
def extract_text_content_from_content_dict(self, msg):
78+
# msg = {
79+
# "role": "assistant",
80+
# "content": [
81+
# {
82+
# "type": "text",
83+
# "text": "some text"
84+
# },
85+
# ],
86+
# }
87+
88+
str_content = ""
89+
for item in msg["content"]:
90+
# item = {
91+
# "type": "text",
92+
# "text": "some text"
93+
# },
94+
95+
assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}"
96+
97+
if ("text" not in item):
98+
logger.warning(
99+
f"Non-text content in message content detected: {item}. Ignoring."
100+
)
101+
should_skip_message = True
102+
return str_content, should_skip_message
103+
104+
if isinstance(item["text"], str):
105+
str_content += str(item["text"])
106+
else:
107+
str_content = ""
108+
109+
should_skip_message = False
110+
return str_content, should_skip_message
85111

86112
def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]:
87113
"""Spawn a timeline from messages.
@@ -101,55 +127,32 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
101127
consider_roles.remove("tool")
102128

103129
for i, msg in enumerate(messages):
130+
104131
if (disable_toolcalls) and (not isinstance(msg["content"], str)):
105132
continue
133+
106134
if msg["role"] not in consider_roles:
107135
continue
136+
108137
if not isinstance(msg["content"], str):
109138
author = "env"
110-
ignore = False
111-
str_content = ""
112-
extracted_tool_call_id = ""
113-
for item_idx, item in enumerate(msg["content"]):
114-
if isinstance(item, dict) and item.get("type") == "tool_result":
115-
is_tool_result_msg = True # 标记为 tool_result 消息
116-
# Extract tool_call_id from the tool_result block
117-
if item.get("id"):
118-
extracted_tool_call_id = item.get("id", "")
119-
output = item.get("output", "")
120-
if isinstance(output, str):
121-
str_content += output
122-
elif isinstance(output, list):
123-
# output can be List[TextBlock | ImageBlock | AudioBlock]
124-
for out_item in output:
125-
if isinstance(out_item, str):
126-
str_content += out_item
127-
elif isinstance(out_item, dict) and "text" in out_item:
128-
str_content += str(out_item["text"])
129-
else:
130-
str_content += str(output)
131-
elif isinstance(item, dict) and "text" in item:
132-
if isinstance(item["text"], str):
133-
str_content += str(item["text"])
134-
else:
135-
str_content = ""
136-
else:
137-
logger.warning(
138-
f"Non-text content in message content detected: {item}. Ignoring."
139-
)
140-
ignore = True
141-
break
142-
msg["content"] = str_content
143-
msg["tool_call_id"] = extracted_tool_call_id # Store extracted tool_call_id
144-
145-
# ★ 关键修复:如果是 tool_result 消息,将 role 恢复为 "tool"(OpenAI 格式)
146-
if is_tool_result_msg and extracted_tool_call_id:
147-
msg["role"] = "tool"
148-
149-
150-
if ignore:
139+
should_skip_message = False
140+
141+
# fix msg content
142+
if msg["content"] is None:
143+
msg["content"] = ""
144+
145+
elif isinstance(msg["content"], list):
146+
msg["content"], should_skip_message = self.extract_text_content_from_content_dict(msg)
147+
148+
else:
149+
raise ValueError(f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}")
150+
151+
if should_skip_message:
151152
continue
152-
msg["content"] = str(msg["content"]) # TODO: better handling mm data
153+
154+
if not isinstance(msg["content"], str):
155+
msg["content"] = str(msg["content"]) # TODO: better handling mm data
153156

154157
if msg["role"] == "system":
155158
author = "initialization"
@@ -169,6 +172,7 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
169172
tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []),
170173
tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""),
171174
token_generator="auto",
175+
name = (msg["name"] if "name" in msg else ""),
172176
first_message=(i == 0),
173177
)
174178
]
@@ -605,25 +609,3 @@ def check_context_token_num_safe(
605609
else:
606610
ret = (False, token_overflow, "token_overflow")
607611
return ret
608-
609-
def get_grouped_steps_openai_format(self) -> List[List[Dict[str, Any]]]:
610-
"""
611-
将 grouped_steps 转换为 OpenAI 格式并返回。
612-
613-
Returns:
614-
OpenAI 格式的轨迹数据 (List of List of dict)
615-
每条消息格式如:
616-
- {"role": "assistant", "content": "...", "tool_calls": [...]}
617-
- {"role": "tool", "content": "...", "tool_call_id": "call_xxx"}
618-
- {"role": "user/system", "content": "..."}
619-
"""
620-
return convert_grouped_steps_to_openai_format(self.grouped_steps)
621-
622-
def get_full_context_openai_format(self) -> List[Dict[str, Any]]:
623-
"""
624-
将当前 full_context 转换为 OpenAI 格式并返回。
625-
626-
Returns:
627-
OpenAI 格式的消息列表 (List of dict)
628-
"""
629-
return [convert_ext_msg_to_openai_format(msg) for msg in self.full_context]

ajet/launcher.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def parse_args():
8686
help="Launch Crafters Env Simulation",
8787
)
8888
parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag")
89+
parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check")
8990
parser.add_argument(
9091
"--kill",
9192
type=str,
@@ -305,7 +306,7 @@ def main():
305306

306307
if args.with_finworld:
307308
pty_launch("finworld")
308-
309+
309310
if args.with_crafters:
310311
pty_launch("crafters")
311312

ajet/schema/extended_msg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
tool_calls=[],
7575
tool_call_id="",
7676
token_logprob_arr=[],
77+
name="", # preserved field, not used currently
7778
first_message=False,
7879
):
7980
self.author = author
@@ -90,6 +91,7 @@ def __init__(
9091
self.tools = tools
9192
self.tool_calls = tool_calls
9293
self.tool_call_id = tool_call_id
94+
self.name = name # preserved field, not used currently
9395
if not isinstance(self.tool_calls, list):
9496
# agent scope sometimes gives weird type for tool_calls, which is against OpenAI schema
9597
self.tool_calls = list(self.tool_calls)

0 commit comments

Comments
 (0)