Merged
Changes from 6 commits
32 changes: 31 additions & 1 deletion ajet/backbone/trainer_verl.py
@@ -17,6 +17,8 @@
from collections import defaultdict
from pprint import pprint
from typing import List, Optional
from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_cmts
from loguru import logger as loguru_logger
medium

This import of loguru_logger is redundant as logger is already imported from loguru on line 27 and this alias is not used. Please remove it to keep the imports clean.


import hydra
import numpy as np
@@ -54,7 +56,14 @@
from ajet.schema.task import Task
from ajet.task_reader import dict_to_ajet_task
from ajet.task_rollout.native_parallel_worker import VerlRolloutManager

from ajet.utils.save_trajectory import save_train_trajectory, save_eval_trajectory
from ajet.utils.msg_converter import (
convert_grouped_steps_to_openai_format,
convert_ext_msg_to_openai_format,
agentscope_to_openai,
openai_to_agentscope,
medium

These functions from msg_converter are imported but appear to be unused in this file. Please remove the unused imports to improve code clarity.

)
from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_cmts

def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor:
"""
@@ -577,6 +586,9 @@ def fit(self): # noqa: C901
tasks, mode="sample", epoch=f"train.{epoch}"
)
logger.info("=" * 10 + "end fit rollout" + "=" * 10)

if self.config.ajet.trainer_common.save_trajectory:
save_train_trajectory(context_tracker_arr, self.global_steps)
logger.info("begin to convert context_tracker_arr to dataproto")
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
logger.info("end convertion")
@@ -602,6 +614,14 @@
),
}
)
from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories
from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories
medium

According to PEP 8, imports should be at the top of the file, not inside a function. This improves readability and avoids re-importing modules. Please move these imports to the top of the file.
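
If these imports are hoisted as suggested, a minimal sketch of the consolidated top-of-file block (assuming no circular-import constraint motivated the function-local import):

    # at the top of trainer_verl.py, next to the existing metric_helper imports
    from ajet.utils.metric_helper.tool_metric_helper import (
        compute_tool_metrics_from_cmts,
        compute_tool_metrics_from_trajectories,
    )
    from ajet.utils.metric_helper.reward_metric_helper import (
        compute_reward_metrics_from_cmts,
        compute_reward_metrics_from_trajectories,
    )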

tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr)
reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr)
if tool_metrics:
metrics.update(tool_metrics)
if reward_metrics:
metrics.update(reward_metrics)
if self.config.ajet.execute_test: # apply a test probe
from swanlab.data.run.main import get_run

@@ -1029,6 +1049,10 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
for ctx_tracker in ctx_trackers:
ctx_tracker.generate_log()

# save eval trajectories
if self.config.ajet.trainer_common.save_trajectory:
save_eval_trajectory(ctx_trackers, self.global_steps)

rewards = [ctx_tracker.reward_structure.raw_reward for ctx_tracker in ctx_trackers]
num_tasks = len(task_results)
assert num_tasks == len(ctx_trackers) // pass_n
@@ -1044,6 +1068,12 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks,
"mean_reward": sum(rewards) / len(rewards) if rewards else 0,
}
reward_metrics = compute_reward_metrics_from_cmts(ctx_trackers, print_debug=True)
tool_metrics = compute_tool_metrics_from_cmts(ctx_trackers)
if tool_metrics:
val_metrics.update(reward_metrics)
if reward_metrics:
val_metrics.update(tool_metrics)
critical

There is a copy-paste error in the logic for updating val_metrics: the conditions are correct, but the dicts being merged are swapped. When tool_metrics is present, val_metrics is updated with reward_metrics, and vice versa, leading to incorrect metric reporting.

Suggested change
if tool_metrics:
val_metrics.update(reward_metrics)
if reward_metrics:
val_metrics.update(tool_metrics)
if tool_metrics:
val_metrics.update(tool_metrics)
if reward_metrics:
val_metrics.update(reward_metrics)

print_dict(
val_metrics,
narrow=True,
2 changes: 2 additions & 0 deletions ajet/context_tracker/base_tracker.py
@@ -1,4 +1,5 @@
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
Comment on lines 1 to +2
medium

The typing module is imported twice, and the first import is now a subset of the second. To improve readability and avoid redundancy, please consolidate these into a single import statement.

Suggested change
from typing import List, Tuple, Union
from typing import List, Union, Tuple, Dict, Optional, Any
from typing import List, Union, Tuple, Dict, Optional, Any


from ajet.schema.extended_msg import (
INVALID_LOG_PROB_VALUE,
@@ -135,6 +136,7 @@ def __init__(self, config, tokenizer, **kwargs):
self.already_mad_flag: bool = False
self.round_cnt = 0
self.generation_prompt_token = None
self.workflow_metadata: Optional[Dict[str, Any]] = None # Initialize workflow_metadata to store tool statistics

assert (
self.config.ajet.data.max_prompt_length
2 changes: 2 additions & 0 deletions ajet/context_tracker/basic_tracker.py
@@ -192,6 +192,8 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List:
}
if ext_msg.tool_calls:
d.update({"tool_calls": ext_msg.tool_calls})
if ext_msg.tool_call_id:
d.update({"tool_call_id": ext_msg.tool_call_id})
result.append(d)
return result

83 changes: 65 additions & 18 deletions ajet/context_tracker/multiagent_tracking.py
@@ -20,7 +20,15 @@
from ajet.utils.color_hsl import adjust_color_hsl
from ajet.utils.compute_madness import compute_string_madness
from ajet.utils.tokenizer import ajet_apply_chat_template

from ajet.utils.msg_converter import (
convert_grouped_steps_to_openai_format,
convert_ext_msg_to_openai_format,
agentscope_to_openai,
openai_to_agentscope,
agentscope_to_openai_grouped,
openai_to_agentscope_grouped,
medium

These functions (agentscope_to_openai, openai_to_agentscope, agentscope_to_openai_grouped, openai_to_agentscope_grouped) are imported but do not appear to be used within this file. Please remove them to keep the import section clean.

)
@dataclass
class TimelineMergingPolicyConfig:
timeline_compare_level: str = "text"
@@ -101,27 +109,43 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
author = "env"
ignore = False
str_content = ""

# fix msg content
if msg["content"] is None:
msg["content"] = ""
elif isinstance(msg["content"], list):
for item in msg["content"]:
if "text" not in item:
logger.warning(
f"Non-text content in message content detected: {item}. Ignoring."
)
ignore = True
break
extracted_tool_call_id = ""
high

The variable is_tool_result_msg is used on line 146 but is only assigned inside the following for loop. If msg["content"] is an empty list, or contains no tool_result items, the assignment never executes and an UnboundLocalError will be raised. Please initialize is_tool_result_msg to False before the loop to prevent this.

                extracted_tool_call_id = ""
                is_tool_result_msg = False
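
A minimal standalone repro of the failure mode, for illustration:

    def has_tool_result(content: list) -> bool:
        for item in content:
            if item.get("type") == "tool_result":
                is_tool_result_msg = True
        return is_tool_result_msg  # never assigned if the loop finds nothing

    try:
        has_tool_result([])
    except UnboundLocalError as e:
        print(e)  # local variable 'is_tool_result_msg' referenced before assignment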

for item_idx, item in enumerate(msg["content"]):
if isinstance(item, dict) and item.get("type") == "tool_result":
is_tool_result_msg = True  # mark as a tool_result message
# Extract tool_call_id from the tool_result block
if item.get("id"):
extracted_tool_call_id = item.get("id", "")
output = item.get("output", "")
if isinstance(output, str):
str_content += output
elif isinstance(output, list):
# output can be List[TextBlock | ImageBlock | AudioBlock]
for out_item in output:
if isinstance(out_item, str):
str_content += out_item
elif isinstance(out_item, dict) and "text" in out_item:
str_content += str(out_item["text"])
else:
str_content += str(output)
elif isinstance(item, dict) and "text" in item:
if isinstance(item["text"], str):
str_content += str(item["text"])
else:
str_content = ""
msg["content"] = str_content
else:
raise ValueError(
f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}"
)
else:
logger.warning(
f"Non-text content in message content detected: {item}. Ignoring."
)
ignore = True
break
msg["content"] = str_content
msg["tool_call_id"] = extracted_tool_call_id # Store extracted tool_call_id

# ★ 关键修复:如果是 tool_result 消息,将 role 恢复为 "tool"(OpenAI 格式)
medium

This comment is in Chinese, which is inconsistent with the English comments in the rest of the codebase. For maintainability, please translate it to English or remove it if the code is self-explanatory.

                # Critical fix: If this is a tool_result message, restore the role to "tool" (OpenAI format).

if is_tool_result_msg and extracted_tool_call_id:
msg["role"] = "tool"


if ignore:
continue
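
For reference, a sketch of the agentscope-style message this branch normalizes; the id and text values are hypothetical:

    msg = {
        "role": "user",
        "content": [
            {"type": "tool_result", "id": "call_abc123",
             "output": [{"type": "text", "text": "3 files found"}]},
        ],
    }
    # After the loop above: msg["role"] == "tool", msg["content"] == "3 files found",
    # and msg["tool_call_id"] == "call_abc123".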
@@ -143,6 +167,7 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to
tokenizer=self.tokenizer,
tools=tools,
tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []),
tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""),
token_generator="auto",
first_message=(i == 0),
)
@@ -580,3 +605,25 @@ def check_context_token_num_safe(
else:
ret = (False, token_overflow, "token_overflow")
return ret

def get_grouped_steps_openai_format(self) -> List[List[Dict[str, Any]]]:
"""
将 grouped_steps 转换为 OpenAI 格式并返回。

Returns:
OpenAI 格式的轨迹数据 (List of List of dict)
每条消息格式如:
- {"role": "assistant", "content": "...", "tool_calls": [...]}
- {"role": "tool", "content": "...", "tool_call_id": "call_xxx"}
- {"role": "user/system", "content": "..."}
"""
medium

This docstring is in Chinese, while the rest of the file's documentation is in English. To maintain consistency and improve readability for all contributors, please translate it to English.

        """
        Converts grouped_steps to OpenAI format and returns the result.
        
        Returns:
            Trajectory data in OpenAI format (List of List of dict).
            Each message is formatted as follows:
            - {"role": "assistant", "content": "...", "tool_calls": [...]}
            - {"role": "tool", "content": "...", "tool_call_id": "call_xxx"}
            - {"role": "user/system", "content": "..."}
        """

return convert_grouped_steps_to_openai_format(self.grouped_steps)

def get_full_context_openai_format(self) -> List[Dict[str, Any]]:
"""
将当前 full_context 转换为 OpenAI 格式并返回。

Returns:
OpenAI 格式的消息列表 (List of dict)
"""
medium

This docstring is in Chinese, which is inconsistent with the English documentation in the rest of the file. Please translate it to English for consistency.

        """
        Converts the current full_context to OpenAI format and returns the result.
        
        Returns:
            A list of messages in OpenAI format (List of dict).
        """

return [convert_ext_msg_to_openai_format(msg) for msg in self.full_context]
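
A brief usage sketch of the two accessors (tracker stands for an already-populated instance; construction is elided):

    trajectory = tracker.get_grouped_steps_openai_format()
    for step_msgs in trajectory:
        for m in step_msgs:
            print(m["role"], m.get("tool_call_id", ""))

    full_context = tracker.get_full_context_openai_format()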
1 change: 1 addition & 0 deletions ajet/default_config/ajet_default.yaml
@@ -235,6 +235,7 @@ ajet:
val_pass_n: 4
save_freq: 20
test_freq: 20
save_trajectory: False # whether to save train/eval trajectories to JSON files
total_epochs: 50
nnodes: 1
n_gpus_per_node: 8
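
The save_train_trajectory / save_eval_trajectory helpers gated by this flag are not shown in the diff; a hypothetical sketch of the train-side helper, assuming it serializes each tracker's OpenAI-format trajectory to a step-stamped JSON file (directory and field layout are illustrative, not the PR's actual implementation):

    import json
    import os

    def save_train_trajectory(context_tracker_arr, global_steps, out_dir="saved_trajectories"):
        # Hypothetical: dump one JSON file per training step.
        os.makedirs(out_dir, exist_ok=True)
        payload = [t.get_grouped_steps_openai_format() for t in context_tracker_arr]
        with open(os.path.join(out_dir, f"train_step_{global_steps}.json"), "w") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)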
17 changes: 14 additions & 3 deletions ajet/launcher.py
@@ -16,7 +16,8 @@
from ajet.utils.pty import pty_launch

set_loguru_default_color()
load_dotenv()
load_dotenv(override=False)  # do not overwrite variables already set in the process environment


def parse_args():
@@ -59,6 +60,12 @@ def parse_args():
default=False,
help="Launch appworld",
)
parser.add_argument(
"--with-finworld",
action="store_true",
default=False,
help="Launch finworld",
)
parser.add_argument(
"--with-webshop",
action="store_true",
@@ -247,8 +254,9 @@ def main():
args = parse_args()

# Enforce GPU availability and free memory threshold before proceeding
if (args.backbone != "debug") and (not args.kill) and (not args.autokill):
check_avail_gpu(min_free_ratio=0.95)
if not args.skip_check_avail_gpu:
if (args.backbone != "debug") and (not args.kill) and (not args.autokill):
check_avail_gpu(min_free_ratio=0.95)

if args.autokill:
args.kill = "ray|vllm|VLLM|python"
@@ -295,6 +303,9 @@ def main():
if args.with_appworld:
pty_launch("appworld")

if args.with_finworld:
pty_launch("finworld")

if args.with_crafters:
pty_launch("crafters")

4 changes: 4 additions & 0 deletions ajet/schema/extended_msg.py
@@ -72,6 +72,7 @@ def __init__(
build_from_uuid="",
tools=[],
tool_calls=[],
tool_call_id="",
token_logprob_arr=[],
first_message=False,
):
@@ -88,6 +89,7 @@
self.clip = clip
self.tools = tools
self.tool_calls = tool_calls
self.tool_call_id = tool_call_id
self.uuid = uuid.uuid4().hex
self.build_from_uuid = build_from_uuid
self.first_message = first_message
@@ -143,6 +145,8 @@ def auto_tokenize_non_first_message(self, tokenizer, tools):
}
if self.tool_calls:
auto_tokenize_target.update({"tool_calls": self.tool_calls})
if self.tool_call_id:
auto_tokenize_target.update({"tool_call_id": self.tool_call_id})
text_frag_to = ajet_apply_chat_template(
tokenizer=tokenizer,
conversation=DUMMY_MSG + [auto_tokenize_target],
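
For context, in the OpenAI chat format a tool result must echo the id of the tool call that produced it, which is what threading tool_call_id through ExtendedMessage enables. A minimal illustration (the function name and values are hypothetical):

    messages = [
        {"role": "assistant", "content": "", "tool_calls": [{
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "search_files", "arguments": '{"query": "reward"}'},
        }]},
        {"role": "tool", "content": "3 files found", "tool_call_id": "call_abc123"},
    ]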
20 changes: 14 additions & 6 deletions ajet/task_rollout/resource_keeper.py
@@ -85,7 +85,9 @@ def _initialize_environment_and_messages(self) -> List[dict]:
params=self.env_params,
)
state_message: dict = init_response["state"]
_, init_messages = self._get_init_messages(state_message)
query, init_messages = self._get_init_messages(state_message)
# Update main_query with actual query from environment
self.workflow_task.task.main_query = query
except Exception as e:
logger.bind(exception=True).exception(
f"encounter exception in env_worker.create_instance~ error={e.args}"
@@ -175,12 +177,18 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]:
)
obs = ""
assert isinstance(env_output, dict)
if ("content" not in env_output["state"]) and ("error" in env_output["state"]):
obs = f"[Error from environment: {env_output['error']}]"
elif env_output["state"]["content"] == "":
obs = "Warning: the environment does not provide any feedback, please provide valid inpu and try again."
# === Support list-type state passthrough ===
# 1. If state is a list (new standard format), pass through directly
if isinstance(env_output["state"], list):
obs = env_output["state"]
# 2. If state is a dict (old format or error)
else:
obs = env_output["state"]["content"]
if ("content" not in env_output["state"]) and ("error" in env_output["state"]):
obs = f"[Error from environment: {env_output['error']}]"
elif env_output["state"].get("content", "") == "":
obs = "Warning: the environment does not provide any feedback, please provide valid inpu and try again."
else:
obs = env_output["state"]["content"]
reward = 0
info = {}
terminate = env_output["is_terminated"]
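
For reference, the state shapes the rewritten branch accepts (values are illustrative):

    # 1. New standard format: state is already a list of content blocks, passed through as-is.
    env_output = {"state": [{"type": "text", "text": "door opened"}], "is_terminated": False}

    # 2. Old dict format signalling an error (no "content" key).
    env_output = {"state": {"error": "tool timeout"}, "is_terminated": True}

    # 3. Old dict format carrying content; an empty string triggers the warning message.
    env_output = {"state": {"content": "door opened"}, "is_terminated": False}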
22 changes: 22 additions & 0 deletions ajet/task_runner/general_runner.py
@@ -1,4 +1,5 @@
import asyncio
from venv import logger
critical

This import is incorrect. The venv module is for creating Python virtual environments; it happens to expose a module-level logger for its own internal use, which is why this common IDE auto-import mistake does not raise an ImportError — instead, log messages are silently routed to the wrong logger. You likely intended to import from loguru.

Suggested change
from venv import logger
from loguru import logger


from ajet import AjetTuner
from ajet import Workflow, WorkflowOutput
@@ -49,13 +50,34 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
workflow_output: WorkflowOutput = asyncio.run(
user_workflow.execute(workflow_task, tuner)
)
# set workflow metadata to context tracker metadata
context_tracker.workflow_metadata = workflow_output.metadata
if workflow_output.reward is not None:
raw_reward, is_success = (
workflow_output.reward,
workflow_output.is_success,
)
else:
raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output)

# Critical fix: after calling `judge`, write the updated `reward_stats` back to `workflow_metadata`,
# ensuring that `native_compat_trainer` reads the value actually computed by `judge`, not the 0 returned by `env`.
if workflow_output.metadata and 'reward_stats' in workflow_output.metadata:
context_tracker.workflow_metadata['reward_stats'] = workflow_output.metadata['reward_stats']
else:
# fallback: If the judge does not update reward_stats, use the default value.
logger.warning(f"[WARN] reward_stats not found in metadata after judge call, creating default values")
default_reward_stats = {
'original_reward': raw_reward,
'penalty': 0.0,
'step_reward': raw_reward,
}
if workflow_output.metadata:
workflow_output.metadata['reward_stats'] = default_reward_stats
context_tracker.workflow_metadata['reward_stats'] = default_reward_stats
else:
context_tracker.workflow_metadata = {'reward_stats': default_reward_stats}

workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue

assert not isinstance(
7 changes: 7 additions & 0 deletions ajet/utils/core_env_vars.py
@@ -40,6 +40,13 @@ def get_runtime_env(is_trinity: bool = False) -> dict:
"AJET_GIT_HASH",
"AJET_REQ_TXT",
"AJET_BENCHMARK_NAME",
"FINANCE_MCP_URL",
# API Keys for RM Gallery and other services
"DASHSCOPE_API_KEY",
"OPENAI_API_KEY",
"OPENAI_BASE_URL",
"API_KEY",
"BASE_URL",
]

for var in optional_env_vars:
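
The loop body is cut off by the diff view; a plausible sketch of the forwarding pattern it implements, assuming only variables actually set in the parent process are copied into the runtime env (an assumption, not the file's verbatim code):

    import os

    env_vars = {}
    for var in optional_env_vars:
        value = os.environ.get(var)
        if value is not None:  # forward only variables set in the parent process
            env_vars[var] = value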