-
Notifications
You must be signed in to change notification settings - Fork 19
feat: Add comprehensive metric tracking and trajectory persistence #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
0a9a213
96b5071
697bc5b
cccc13f
f52d83e
d930ffb
b548324
836e380
5977031
c10e96a
a7dc96d
69ca4d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,6 +17,8 @@ | |||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||
| from pprint import pprint | ||||||||||||||||||
| from typing import List, Optional | ||||||||||||||||||
| from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_cmts | ||||||||||||||||||
| from loguru import logger as loguru_logger | ||||||||||||||||||
|
|
||||||||||||||||||
| import hydra | ||||||||||||||||||
| import numpy as np | ||||||||||||||||||
|
|
@@ -54,7 +56,14 @@ | |||||||||||||||||
| from ajet.schema.task import Task | ||||||||||||||||||
| from ajet.task_reader import dict_to_ajet_task | ||||||||||||||||||
| from ajet.task_rollout.native_parallel_worker import VerlRolloutManager | ||||||||||||||||||
|
|
||||||||||||||||||
| from ajet.utils.save_trajectory import save_train_trajectory, save_eval_trajectory | ||||||||||||||||||
| from ajet.utils.msg_converter import ( | ||||||||||||||||||
| convert_grouped_steps_to_openai_format, | ||||||||||||||||||
| convert_ext_msg_to_openai_format, | ||||||||||||||||||
| agentscope_to_openai, | ||||||||||||||||||
| openai_to_agentscope, | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| ) | ||||||||||||||||||
| from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_cmts | ||||||||||||||||||
|
|
||||||||||||||||||
| def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor: | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
@@ -577,6 +586,9 @@ def fit(self): # noqa: C901 | |||||||||||||||||
| tasks, mode="sample", epoch=f"train.{epoch}" | ||||||||||||||||||
| ) | ||||||||||||||||||
| logger.info("=" * 10 + "end fit rollout" + "=" * 10) | ||||||||||||||||||
|
|
||||||||||||||||||
| if self.config.ajet.trainer_common.save_trajectory: | ||||||||||||||||||
| save_train_trajectory(context_tracker_arr, self.global_steps) | ||||||||||||||||||
| logger.info("begin to convert context_tracker_arr to dataproto") | ||||||||||||||||||
| gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr) | ||||||||||||||||||
| logger.info("end convertion") | ||||||||||||||||||
|
|
@@ -602,6 +614,14 @@ def fit(self): # noqa: C901 | |||||||||||||||||
| ), | ||||||||||||||||||
| } | ||||||||||||||||||
| ) | ||||||||||||||||||
| from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories | ||||||||||||||||||
| from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||
| tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr) | ||||||||||||||||||
| reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr) | ||||||||||||||||||
| if tool_metrics: | ||||||||||||||||||
| metrics.update(tool_metrics) | ||||||||||||||||||
| if reward_metrics: | ||||||||||||||||||
| metrics.update(reward_metrics) | ||||||||||||||||||
| if self.config.ajet.execute_test: # apply a test probe | ||||||||||||||||||
| from swanlab.data.run.main import get_run | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -1029,6 +1049,10 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): | |||||||||||||||||
| for ctx_tracker in ctx_trackers: | ||||||||||||||||||
| ctx_tracker.generate_log() | ||||||||||||||||||
|
|
||||||||||||||||||
| # save eval trajectories | ||||||||||||||||||
| if self.config.ajet.trainer_common.save_trajectory: | ||||||||||||||||||
| save_eval_trajectory(ctx_trackers, self.global_steps) | ||||||||||||||||||
|
|
||||||||||||||||||
| rewards = [ctx_tracker.reward_structure.raw_reward for ctx_tracker in ctx_trackers] | ||||||||||||||||||
| num_tasks = len(task_results) | ||||||||||||||||||
| assert num_tasks == len(ctx_trackers) // pass_n | ||||||||||||||||||
|
|
@@ -1044,6 +1068,12 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): | |||||||||||||||||
| f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks, | ||||||||||||||||||
| "mean_reward": sum(rewards) / len(rewards) if rewards else 0, | ||||||||||||||||||
| } | ||||||||||||||||||
| reward_metrics = compute_reward_metrics_from_cmts(ctx_trackers, print_debug=True) | ||||||||||||||||||
| tool_metrics = compute_tool_metrics_from_cmts(ctx_trackers) | ||||||||||||||||||
| if tool_metrics: | ||||||||||||||||||
| val_metrics.update(reward_metrics) | ||||||||||||||||||
| if reward_metrics: | ||||||||||||||||||
| val_metrics.update(tool_metrics) | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is a copy-paste error in the logic for updating `val_metrics`: the `if tool_metrics:` branch calls `val_metrics.update(reward_metrics)` and the `if reward_metrics:` branch calls `val_metrics.update(tool_metrics)` — the two dictionaries are swapped, so each set of metrics is only added when the *other* one is non-empty.
Suggested change
|
||||||||||||||||||
| print_dict( | ||||||||||||||||||
| val_metrics, | ||||||||||||||||||
| narrow=True, | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||
| from typing import List, Tuple, Union | ||||||||
| from typing import List, Union, Tuple, Dict, Optional, Any | ||||||||
|
Comment on lines
1
to
+2
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The new `from typing import ...` line duplicates the existing import on the line above; please merge the two into a single `from typing import Any, Dict, List, Optional, Tuple, Union` statement instead of keeping both.
Suggested change
|
||||||||
|
|
||||||||
| from ajet.schema.extended_msg import ( | ||||||||
| INVALID_LOG_PROB_VALUE, | ||||||||
|
|
@@ -135,6 +136,7 @@ def __init__(self, config, tokenizer, **kwargs): | |||||||
| self.already_mad_flag: bool = False | ||||||||
| self.round_cnt = 0 | ||||||||
| self.generation_prompt_token = None | ||||||||
| self.workflow_metadata: Optional[Dict[str, Any]] = None # Initialize workflow_metadata to store tool statistics | ||||||||
|
|
||||||||
| assert ( | ||||||||
| self.config.ajet.data.max_prompt_length | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,15 @@ | |
| from ajet.utils.color_hsl import adjust_color_hsl | ||
| from ajet.utils.compute_madness import compute_string_madness | ||
| from ajet.utils.tokenizer import ajet_apply_chat_template | ||
|
|
||
| # | ||
| from ajet.utils.msg_converter import ( | ||
| convert_grouped_steps_to_openai_format, | ||
| convert_ext_msg_to_openai_format, | ||
| agentscope_to_openai, | ||
| openai_to_agentscope, | ||
| agentscope_to_openai_grouped, | ||
| openai_to_agentscope_grouped, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| ) | ||
| @dataclass | ||
| class TimelineMergingPolicyConfig: | ||
| timeline_compare_level: str = "text" | ||
|
|
@@ -101,27 +109,43 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to | |
| author = "env" | ||
| ignore = False | ||
| str_content = "" | ||
|
|
||
| # fix msg content | ||
| if msg["content"] is None: | ||
| msg["content"] = "" | ||
| elif isinstance(msg["content"], list): | ||
| for item in msg["content"]: | ||
| if "text" not in item: | ||
| logger.warning( | ||
| f"Non-text content in message content detected: {item}. Ignoring." | ||
| ) | ||
| ignore = True | ||
| break | ||
| extracted_tool_call_id = "" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable extracted_tool_call_id = ""
is_tool_result_msg = False |
||
| for item_idx, item in enumerate(msg["content"]): | ||
| if isinstance(item, dict) and item.get("type") == "tool_result": | ||
| is_tool_result_msg = True # 标记为 tool_result 消息 | ||
| # Extract tool_call_id from the tool_result block | ||
| if item.get("id"): | ||
| extracted_tool_call_id = item.get("id", "") | ||
| output = item.get("output", "") | ||
| if isinstance(output, str): | ||
| str_content += output | ||
| elif isinstance(output, list): | ||
| # output can be List[TextBlock | ImageBlock | AudioBlock] | ||
| for out_item in output: | ||
| if isinstance(out_item, str): | ||
| str_content += out_item | ||
| elif isinstance(out_item, dict) and "text" in out_item: | ||
| str_content += str(out_item["text"]) | ||
| else: | ||
| str_content += str(output) | ||
| elif isinstance(item, dict) and "text" in item: | ||
| if isinstance(item["text"], str): | ||
| str_content += str(item["text"]) | ||
| else: | ||
| str_content = "" | ||
| msg["content"] = str_content | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}" | ||
| ) | ||
| else: | ||
| logger.warning( | ||
| f"Non-text content in message content detected: {item}. Ignoring." | ||
| ) | ||
| ignore = True | ||
| break | ||
| msg["content"] = str_content | ||
| msg["tool_call_id"] = extracted_tool_call_id # Store extracted tool_call_id | ||
|
|
||
| # ★ 关键修复:如果是 tool_result 消息,将 role 恢复为 "tool"(OpenAI 格式) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if is_tool_result_msg and extracted_tool_call_id: | ||
| msg["role"] = "tool" | ||
|
|
||
|
|
||
| if ignore: | ||
| continue | ||
|
|
@@ -143,6 +167,7 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to | |
| tokenizer=self.tokenizer, | ||
| tools=tools, | ||
| tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []), | ||
| tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""), | ||
| token_generator="auto", | ||
| first_message=(i == 0), | ||
| ) | ||
|
|
@@ -580,3 +605,25 @@ def check_context_token_num_safe( | |
| else: | ||
| ret = (False, token_overflow, "token_overflow") | ||
| return ret | ||
|
|
||
| def get_grouped_steps_openai_format(self) -> List[List[Dict[str, Any]]]: | ||
| """ | ||
| 将 grouped_steps 转换为 OpenAI 格式并返回。 | ||
|
|
||
| Returns: | ||
| OpenAI 格式的轨迹数据 (List of List of dict) | ||
| 每条消息格式如: | ||
| - {"role": "assistant", "content": "...", "tool_calls": [...]} | ||
| - {"role": "tool", "content": "...", "tool_call_id": "call_xxx"} | ||
| - {"role": "user/system", "content": "..."} | ||
| """ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This docstring is in Chinese, while the rest of the file's documentation is in English. To maintain consistency and improve readability for all contributors, please translate it to English. """
Converts grouped_steps to OpenAI format and returns the result.
Returns:
Trajectory data in OpenAI format (List of List of dict).
Each message is formatted as follows:
- {"role": "assistant", "content": "...", "tool_calls": [...]}
- {"role": "tool", "content": "...", "tool_call_id": "call_xxx"}
- {"role": "user/system", "content": "..."}
""" |
||
| return convert_grouped_steps_to_openai_format(self.grouped_steps) | ||
|
|
||
| def get_full_context_openai_format(self) -> List[Dict[str, Any]]: | ||
| """ | ||
| 将当前 full_context 转换为 OpenAI 格式并返回。 | ||
|
|
||
| Returns: | ||
| OpenAI 格式的消息列表 (List of dict) | ||
| """ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| return [convert_ext_msg_to_openai_format(msg) for msg in self.full_context] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import asyncio | ||
| from venv import logger | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| from ajet import AjetTuner | ||
| from ajet import Workflow, WorkflowOutput | ||
|
|
@@ -49,13 +50,34 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: | |
| workflow_output: WorkflowOutput = asyncio.run( | ||
| user_workflow.execute(workflow_task, tuner) | ||
| ) | ||
| # set workflow metadata to context tracker metadata | ||
| context_tracker.workflow_metadata = workflow_output.metadata | ||
| if workflow_output.reward is not None: | ||
| raw_reward, is_success = ( | ||
| workflow_output.reward, | ||
| workflow_output.is_success, | ||
| ) | ||
| else: | ||
| raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) | ||
|
|
||
| # ✅ Critical Fix: After calling `judge`, write the updated `reward_stats` back to `workflow_metadata` | ||
| # # Ensure that `native_compat_trainer` reads the actual value calculated by `judge`, not the 0 value returned by `env`. | ||
| if workflow_output.metadata and 'reward_stats' in workflow_output.metadata: | ||
| context_tracker.workflow_metadata['reward_stats'] = workflow_output.metadata['reward_stats'] | ||
| else: | ||
| # fallback: If the judge does not update reward_stats, use the default value. | ||
| logger.warning(f"[WARN] reward_stats not found in metadata after judge call, creating default values") | ||
| default_reward_stats = { | ||
| 'original_reward': raw_reward, | ||
| 'penalty': 0.0, | ||
| 'step_reward': raw_reward, | ||
| } | ||
| if workflow_output.metadata: | ||
| workflow_output.metadata['reward_stats'] = default_reward_stats | ||
| context_tracker.workflow_metadata['reward_stats'] = default_reward_stats | ||
| else: | ||
| context_tracker.workflow_metadata = {'reward_stats': default_reward_stats} | ||
|
|
||
| workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue | ||
|
|
||
| assert not isinstance( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This import of
`loguru_logger` is redundant, as `logger` is already imported from `loguru` on line 27 and this alias is not used. Please remove it to keep the imports clean.