Skip to content

Commit cccc13f

Browse files
committed
feat: add message format converter and improve tool call handling
- Add msg_converter.py for bidirectional OpenAI<->AgentScope format conversion
- Support tool_call_id in basic_tracker context serialization
- Update multiagent_tracking to use msg_converter utilities
- Update schema comments to English
- Improve workflow_metadata documentation in base_tracker
1 parent 697bc5b commit cccc13f

File tree

5 files changed

+402
-19
lines changed

5 files changed

+402
-19
lines changed

ajet/context_tracker/base_tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(self, config, tokenizer, **kwargs):
136136
self.already_mad_flag: bool = False
137137
self.round_cnt = 0
138138
self.generation_prompt_token = None
139-
self.workflow_metadata: Optional[Dict[str, Any]] = None # 初始化 workflow_metadata 以存储工具统计信息
139+
self.workflow_metadata: Optional[Dict[str, Any]] = None # Initialize workflow_metadata to store tool statistics
140140

141141
assert (
142142
self.config.ajet.data.max_prompt_length

ajet/context_tracker/basic_tracker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
192192
}
193193
if ext_msg.tool_calls:
194194
d.update({"tool_calls": ext_msg.tool_calls})
195+
if ext_msg.tool_call_id:
196+
d.update({"tool_call_id": ext_msg.tool_call_id})
195197
result.append(d)
196198
return result
197199

ajet/context_tracker/multiagent_tracking.py

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@
2020
from ajet.utils.color_hsl import adjust_color_hsl
2121
from ajet.utils.compute_madness import compute_string_madness
2222
from ajet.utils.tokenizer import ajet_apply_chat_template
23-
23+
#
24+
from ajet.utils.msg_converter import (
25+
convert_grouped_steps_to_openai_format,
26+
convert_ext_msg_to_openai_format,
27+
agentscope_to_openai,
28+
openai_to_agentscope,
29+
agentscope_to_openai_grouped,
30+
openai_to_agentscope_grouped,
31+
)
2432
@dataclass
2533
class TimelineMergingPolicyConfig:
2634
timeline_compare_level: str = "text"
@@ -101,27 +109,43 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
101109
author = "env"
102110
ignore = False
103111
str_content = ""
104-
105-
# fix msg content
106-
if msg["content"] is None:
107-
msg["content"] = ""
108-
elif isinstance(msg["content"], list):
109-
for item in msg["content"]:
110-
if "text" not in item:
111-
logger.warning(
112-
f"Non-text content in message content detected: {item}. Ignoring."
113-
)
114-
ignore = True
115-
break
112+
extracted_tool_call_id = ""
113+
for item_idx, item in enumerate(msg["content"]):
114+
if isinstance(item, dict) and item.get("type") == "tool_result":
115+
is_tool_result_msg = True # Mark this as a tool_result message
116+
# Extract tool_call_id from the tool_result block
117+
if item.get("id"):
118+
extracted_tool_call_id = item.get("id", "")
119+
output = item.get("output", "")
120+
if isinstance(output, str):
121+
str_content += output
122+
elif isinstance(output, list):
123+
# output can be List[TextBlock | ImageBlock | AudioBlock]
124+
for out_item in output:
125+
if isinstance(out_item, str):
126+
str_content += out_item
127+
elif isinstance(out_item, dict) and "text" in out_item:
128+
str_content += str(out_item["text"])
129+
else:
130+
str_content += str(output)
131+
elif isinstance(item, dict) and "text" in item:
116132
if isinstance(item["text"], str):
117133
str_content += str(item["text"])
118134
else:
119135
str_content = ""
120-
msg["content"] = str_content
121-
else:
122-
raise ValueError(
123-
f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}"
124-
)
136+
else:
137+
logger.warning(
138+
f"Non-text content in message content detected: {item}. Ignoring."
139+
)
140+
ignore = True
141+
break
142+
msg["content"] = str_content
143+
msg["tool_call_id"] = extracted_tool_call_id # Store extracted tool_call_id
144+
145+
# ★ Key fix: if this is a tool_result message, restore role to "tool" (OpenAI format)
146+
if is_tool_result_msg and extracted_tool_call_id:
147+
msg["role"] = "tool"
148+
125149

126150
if ignore:
127151
continue
@@ -143,6 +167,7 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
143167
tokenizer=self.tokenizer,
144168
tools=tools,
145169
tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []),
170+
tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""),
146171
token_generator="auto",
147172
first_message=(i == 0),
148173
)
@@ -580,3 +605,25 @@ def check_context_token_num_safe(
580605
else:
581606
ret = (False, token_overflow, "token_overflow")
582607
return ret
608+
609+
def get_grouped_steps_openai_format(self) -> List[List[Dict[str, Any]]]:
610+
"""
611+
Convert grouped_steps to OpenAI format and return it.
612+
613+
Returns:
614+
Trajectory data in OpenAI format (List of List of dict)
615+
Each message has a format such as:
616+
- {"role": "assistant", "content": "...", "tool_calls": [...]}
617+
- {"role": "tool", "content": "...", "tool_call_id": "call_xxx"}
618+
- {"role": "user/system", "content": "..."}
619+
"""
620+
return convert_grouped_steps_to_openai_format(self.grouped_steps)
621+
622+
def get_full_context_openai_format(self) -> List[Dict[str, Any]]:
623+
"""
624+
Convert the current full_context to OpenAI format and return it.
625+
626+
Returns:
627+
List of messages in OpenAI format (List of dict)
628+
"""
629+
return [convert_ext_msg_to_openai_format(msg) for msg in self.full_context]

ajet/schema/extended_msg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
build_from_uuid="",
7373
tools=[],
7474
tool_calls=[],
75+
tool_call_id="",
7576
token_logprob_arr=[],
7677
first_message=False,
7778
):
@@ -88,6 +89,7 @@ def __init__(
8889
self.clip = clip
8990
self.tools = tools
9091
self.tool_calls = tool_calls
92+
self.tool_call_id = tool_call_id
9193
self.uuid = uuid.uuid4().hex
9294
self.build_from_uuid = build_from_uuid
9395
self.first_message = first_message
@@ -143,6 +145,8 @@ def auto_tokenize_non_first_message(self, tokenizer, tools):
143145
}
144146
if self.tool_calls:
145147
auto_tokenize_target.update({"tool_calls": self.tool_calls})
148+
if self.tool_call_id:
149+
auto_tokenize_target.update({"tool_call_id": self.tool_call_id})
146150
text_frag_to = ajet_apply_chat_template(
147151
tokenizer=tokenizer,
148152
conversation=DUMMY_MSG + [auto_tokenize_target],

0 commit comments

Comments
 (0)