diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 60c9d6ed..3cbb08ad 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -54,7 +54,7 @@ from ajet.schema.task import Task from ajet.task_reader import dict_to_ajet_task from ajet.task_rollout.native_parallel_worker import VerlRolloutManager - +from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor: """ @@ -602,6 +602,8 @@ def fit(self): # noqa: C901 ), } ) + save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train") + update_metrics(context_tracker_arr, metrics) if self.config.ajet.execute_test: # apply a test probe from swanlab.data.run.main import get_run @@ -1044,6 +1046,8 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch): f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks, "mean_reward": sum(rewards) / len(rewards) if rewards else 0, } + save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval") + update_metrics(ctx_trackers, val_metrics) print_dict( val_metrics, narrow=True, diff --git a/ajet/context_tracker/base_tracker.py b/ajet/context_tracker/base_tracker.py index c9244d7b..0ff706fa 100644 --- a/ajet/context_tracker/base_tracker.py +++ b/ajet/context_tracker/base_tracker.py @@ -1,4 +1,6 @@ from typing import List, Tuple, Union +from typing import List, Union, Tuple, Dict, Optional, Any +from ajet.schema.task import WorkflowTask from ajet.schema.extended_msg import ( INVALID_LOG_PROB_VALUE, @@ -110,10 +112,14 @@ def replace_token_ids( class BaseTracker(object): - def __init__(self, config, tokenizer, **kwargs): - self.task_batch_index = kwargs.get("task_batch_index", "undefined") - self.task_tag = kwargs.get("task_tag", "undefined") - self.task_id = kwargs.get("task_id", "undefined") + def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs): + + self.workflow_task = workflow_task + self.task_batch_index = self.workflow_task.task_batch_index + self.task_tag = self.workflow_task.task_tag + self.task_id = self.workflow_task.task_id + self.episode_uuid = self.workflow_task.episode_uuid + self.config = config self.tokenizer = tokenizer self.saved_timelines: List[List[ExtendedMessage]] = [] @@ -135,6 +141,7 @@ def __init__(self, config, tokenizer, **kwargs): self.already_mad_flag: bool = False self.round_cnt = 0 self.generation_prompt_token = None + self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None # Initialize workflow_metadata to store tool statistics assert ( self.config.ajet.data.max_prompt_length diff --git a/ajet/context_tracker/basic_tracker.py b/ajet/context_tracker/basic_tracker.py index 100e5d2c..44d81cb7 100644 --- a/ajet/context_tracker/basic_tracker.py +++ b/ajet/context_tracker/basic_tracker.py @@ -192,6 +192,8 @@ def to_role_content(self, ext_msg_array: List[ExtendedMessage]) -> List: } if ext_msg.tool_calls: d.update({"tool_calls": ext_msg.tool_calls}) + if ext_msg.tool_call_id: + d.update({"tool_call_id": ext_msg.tool_call_id}) result.append(d) return result diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index fce83a91..4dd75ee1 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -50,7 +50,6 @@ def __init__( config, should_interrupt_fn, generated_token_callback_fn, - episode_uuid: str, 
**kwargs, ): super().__init__(config, tokenizer, **kwargs) @@ -61,7 +60,6 @@ def __init__( self.output_kwargs = {} self.input_kwargs = {} self.timeline_cache = {} - self.episode_uuid = episode_uuid def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool = False): @@ -74,6 +72,40 @@ def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool tools[i]["function"]["parameters"] = tools[i]["function"].pop("parameters") return tools + def extract_text_content_from_content_dict(self, msg): + # msg = { + # "role": "assistant", + # "content": [ + # { + # "type": "text", + # "text": "some text" + # }, + # ], + # } + + str_content = "" + for item in msg["content"]: + # item = { + # "type": "text", + # "text": "some text" + # }, + + assert isinstance(item, dict), f"Unsupported non-dict item in message content: {item}. Full message: {msg}" + + if ("text" not in item): + logger.warning( + f"Non-text content in message content detected: {item}. Ignoring." + ) + should_skip_message = True + return str_content, should_skip_message + + if isinstance(item["text"], str): + str_content += str(item["text"]) + else: + str_content = "" + + should_skip_message = False + return str_content, should_skip_message def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_toolcalls: bool = False) -> List[ExtendedMessage]: """Spawn a timeline from messages. @@ -93,39 +125,32 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to consider_roles.remove("tool") for i, msg in enumerate(messages): + if (disable_toolcalls) and (not isinstance(msg["content"], str)): continue + if msg["role"] not in consider_roles: continue + if not isinstance(msg["content"], str): author = "env" - ignore = False - str_content = "" + should_skip_message = False # fix msg content if msg["content"] is None: msg["content"] = "" + elif isinstance(msg["content"], list): - for item in msg["content"]: - if "text" not in item: - logger.warning( - f"Non-text content in message content detected: {item}. Ignoring." 
- ) - ignore = True - break - if isinstance(item["text"], str): - str_content += str(item["text"]) - else: - str_content = "" - msg["content"] = str_content + msg["content"], should_skip_message = self.extract_text_content_from_content_dict(msg) + else: - raise ValueError( - f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}" - ) + raise ValueError(f"Unsupported non-str message content type: {type(msg['content'])}, Message:\n {msg}") - if ignore: + if should_skip_message: continue - msg["content"] = str(msg["content"]) # TODO: better handling mm data + + if not isinstance(msg["content"], str): + msg["content"] = str(msg["content"]) # TODO: better handling mm data if msg["role"] == "system": author = "initialization" @@ -143,7 +168,9 @@ def step_spawn_timeline(self, messages: List[dict], tools: List = [], disable_to tokenizer=self.tokenizer, tools=tools, tool_calls=(msg["tool_calls"] if "tool_calls" in msg else []), + tool_call_id=(msg["tool_call_id"] if "tool_call_id" in msg else ""), token_generator="auto", + name = (msg["name"] if "name" in msg else ""), first_message=(i == 0), ) ] diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 33a7ea1f..ad4d6d50 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -231,29 +231,57 @@ ajet: # trainer common configurations trainer_common: + + # validation before training val_before_train: False val_pass_n: 4 + + # save and test frequency (in step) save_freq: 20 test_freq: 20 + + # total training epochs total_epochs: 50 + nnodes: 1 n_gpus_per_node: 8 + + # logger selection logger: swanlab + + # algorithm setting algorithm: adv_estimator: grpo use_kl_in_reward: False + + # number of optimizer.step per big batch mini_batch_num: 1 + + # verl offload configs fsdp_config: param_offload: True optimizer_offload: True + + # learning rate optim: lr: 1e-6 + + # enable KL loss regularization use_kl_loss: True + + # kl divergence loss coefficient kl_loss_coef: 0.002 kl_loss_type: low_var_kl + + # Ulysses specific configs ulysses_sequence_parallel_size: 1 + + # base directory to save checkpoints checkpoint_base_dir: ./saved_checkpoints + # whether to save train/eval trajectories to JSON files + save_trajectory_as_json_file: False + diff --git a/ajet/launcher.py b/ajet/launcher.py index 0b7ea94b..73a347aa 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -16,7 +16,7 @@ from ajet.utils.pty import pty_launch set_loguru_default_color() -load_dotenv() +load_dotenv(override=False) def parse_args(): @@ -59,6 +59,12 @@ def parse_args(): default=False, help="Launch appworld", ) + parser.add_argument( + "--with-finworld", + action="store_true", + default=False, + help="Launch finworld", + ) parser.add_argument( "--with-webshop", action="store_true", @@ -79,6 +85,7 @@ def parse_args(): help="Launch Crafters Env Simulation", ) parser.add_argument("--reboot", action="store_true", default=False, help="reboot flag") + parser.add_argument("--skip-check-avail-gpu", action="store_true", default=False, help="Skip GPU availability check") parser.add_argument( "--kill", type=str, @@ -247,8 +254,9 @@ def main(): args = parse_args() # Enforce GPU availability and free memory threshold before proceeding - if (args.backbone != "debug") and (not args.kill) and (not args.autokill): - check_avail_gpu(min_free_ratio=0.95) + if not args.skip_check_avail_gpu: + if (args.backbone != "debug") and (not args.kill) and (not args.autokill): + 
check_avail_gpu(min_free_ratio=0.95) if args.autokill: args.kill = "ray|vllm|VLLM|python" @@ -295,6 +303,9 @@ def main(): if args.with_appworld: pty_launch("appworld") + if args.with_finworld: + pty_launch("finworld") + if args.with_crafters: pty_launch("crafters") diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index f3d5a59e..dfaa7460 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -72,7 +72,9 @@ def __init__( build_from_uuid="", tools=[], tool_calls=[], + tool_call_id="", token_logprob_arr=[], + name="", # preserved field, not used currently first_message=False, ): self.author = author @@ -88,6 +90,8 @@ def __init__( self.clip = clip self.tools = tools self.tool_calls = tool_calls + self.tool_call_id = tool_call_id + self.name = name # preserved field, not used currently if not isinstance(self.tool_calls, list): # agent scope sometimes gives weird type for tool_calls, which is against OpenAI schema self.tool_calls = list(self.tool_calls) @@ -146,6 +150,8 @@ def auto_tokenize_non_first_message(self, tokenizer, tools): } if self.tool_calls: auto_tokenize_target.update({"tool_calls": self.tool_calls}) + if self.tool_call_id: + auto_tokenize_target.update({"tool_call_id": self.tool_call_id}) text_frag_to = ajet_apply_chat_template( tokenizer=tokenizer, conversation=DUMMY_MSG + [auto_tokenize_target], diff --git a/ajet/schema/task.py b/ajet/schema/task.py index f93cff97..6d94796c 100644 --- a/ajet/schema/task.py +++ b/ajet/schema/task.py @@ -43,3 +43,4 @@ class WorkflowOutput(BaseModel): reward: Union[float, List[float], None] = Field(default=None) is_success: Union[bool, None] = Field(default=None) metadata: Dict[str, Any] = Field(default_factory=dict) + log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict) diff --git a/ajet/task_rollout/resource_keeper.py b/ajet/task_rollout/resource_keeper.py index e2ec4c14..2498b415 100644 --- a/ajet/task_rollout/resource_keeper.py +++ b/ajet/task_rollout/resource_keeper.py @@ -85,7 +85,9 @@ def _initialize_environment_and_messages(self) -> List[dict]: params=self.env_params, ) state_message: dict = init_response["state"] - _, init_messages = self._get_init_messages(state_message) + query, init_messages = self._get_init_messages(state_message) + # Update main_query with actual query from environment + self.workflow_task.task.main_query = query except Exception as e: logger.bind(exception=True).exception( f"encounter exception in env_worker.create_instance~ error={e.args}" ) @@ -176,16 +178,23 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]: ) obs = "" assert isinstance(env_output, dict) - if ("content" not in env_output["state"]) and ("error" in env_output["state"]): - obs = f"[Error from environment: {env_output['error']}]" - elif env_output["state"]["content"] == "": - obs = "Warning: the environment does not provide any feedback, please provide valid inpu and try again." + + if isinstance(env_output["state"], list): + # 1. If state is a list (new standard format), pass through directly + obs = env_output["state"] + else: + # 2. If state is a dict (old format or error) + if ("content" not in env_output["state"]) and ("error" in env_output["state"]): + obs = f"[Error from environment: {env_output['error']}]" + elif env_output["state"].get("content", "") == "": + obs = "Warning: the environment does not provide any feedback, please provide valid input and try again." 
+ else: + obs = env_output["state"]["content"] + reward = 0 info = {} terminate = env_output["is_terminated"] - return obs, reward, terminate, info + return obs, reward, terminate, info # type: ignore def reset(self) -> str: """Reset gym environment.""" diff --git a/ajet/task_runner/general_runner.py b/ajet/task_runner/general_runner.py index e88ec323..94271cbd 100644 --- a/ajet/task_runner/general_runner.py +++ b/ajet/task_runner/general_runner.py @@ -1,4 +1,5 @@ import asyncio +from loguru import logger from ajet import AjetTuner from ajet import Workflow, WorkflowOutput @@ -16,9 +17,6 @@ class GeneralRunner(BaseAgentRunner): def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: observation_window = workflow_task.observation_window task_thread_index = workflow_task.task_thread_index - task_batch_index = workflow_task.task_batch_index - task_tag = workflow_task.task_tag - task_id = workflow_task.task_id workflow_import = self.config.ajet.rollout.user_workflow workflow_cls = dynamic_import(workflow_import) @@ -33,10 +31,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: llm_inference_fn=self.llm_inference_fn, tokenizer=self.tokenizer, config=self.config, - task_batch_index=task_batch_index, - task_tag=task_tag, - task_id=task_id, - episode_uuid=workflow_task.episode_uuid, + workflow_task = workflow_task, **hooks, ) tuner = AjetTuner( @@ -45,7 +40,6 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: user_workflow=user_workflow, config=self.config, ) - workflow_output: WorkflowOutput = asyncio.run( user_workflow.execute(workflow_task, tuner) ) @@ -56,6 +50,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: ) else: raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output) + workflow_task.gym_env = None # clear gym env client reference to avoid serialization issue assert not isinstance( @@ -73,12 +68,11 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker: ) context_tracker.process_reward(reward) # generate token before merging - context_tracker.task_id = task_id - context_tracker.task_tag = task_tag context_tracker.group_merge() # after merging, process and align reward again context_tracker.process_reward(reward) # mark the thread as ended observation_window["step"][task_thread_index] = -1 tuner.terminate_episode() + context_tracker.log_metrics = workflow_output.log_metrics return context_tracker diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index 6c96358e..d134b568 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -40,6 +40,13 @@ def get_runtime_env(is_trinity: bool = False) -> dict: "AJET_GIT_HASH", "AJET_REQ_TXT", "AJET_BENCHMARK_NAME", + "FINANCE_MCP_URL", + # API Keys for RM Gallery and other services + "DASHSCOPE_API_KEY", + "OPENAI_API_KEY", + "OPENAI_BASE_URL", + "API_KEY", + "BASE_URL", ] for var in optional_env_vars: diff --git a/ajet/utils/metric_helper/__init__.py b/ajet/utils/metric_helper/__init__.py new file mode 100644 index 00000000..a9702d5d --- /dev/null +++ b/ajet/utils/metric_helper/__init__.py @@ -0,0 +1,17 @@ +from ajet.utils.metric_helper.save_trajectory_as_json import save_trajectory_as_json +from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories +from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories + + +def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix): + if 
config.ajet.trainer_common.save_trajectory_as_json_file: + save_trajectory_as_json(ctx_trackers, global_steps, prefix) + +def update_metrics(context_tracker_arr, metrics:dict): + tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr) + reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr) + if tool_metrics: + metrics.update(tool_metrics) + if reward_metrics: + metrics.update(reward_metrics) + return \ No newline at end of file diff --git a/ajet/utils/metric_helper/reward_metric_helper.py b/ajet/utils/metric_helper/reward_metric_helper.py new file mode 100644 index 00000000..b6cf5918 --- /dev/null +++ b/ajet/utils/metric_helper/reward_metric_helper.py @@ -0,0 +1,231 @@ +""" +FinWorld Reward Metrics Helper + +Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting. +Decouples finworld-specific logic from core code, reducing intrusion into native_compat_trainer. + +SwanLab metrics directory structure: +- rewards/ Top-level aggregated scores +- rewards/dimensions/ Raw scores (unweighted) +- rewards/contribution/ Weighted contributions +- judge_time/ Judge time consumption statistics +""" + +from typing import List, Dict, Any, Optional +import numpy as np + + +def extract_reward_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: + """ + Extract reward_stats from trajectories list. + + Args: + trajectories: List of trajectory objects containing workflow_metadata + + Returns: + List of reward_stats dictionaries + """ + reward_stats_list = [] + for traj in trajectories: + if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: + if 'reward_stats' in traj.workflow_metadata: + reward_stats_list.append(traj.workflow_metadata['reward_stats']) + return reward_stats_list + + +def extract_reward_stats_from_cmts(cmts: List[Any]) -> tuple[List[Dict[str, Any]], Dict[str, int]]: + """ + Extract reward_stats from cmts list and return debug statistics. + + Args: + cmts: List of cmt objects containing workflow_metadata + + Returns: + Tuple of (reward_stats_list, debug_stats) + """ + reward_stats_list = [] + debug_stats = { + 'total_cmts': len(cmts), + 'has_workflow_metadata': 0, + 'has_reward_stats': 0, + } + + for _cmt in cmts: + if hasattr(_cmt, 'workflow_metadata') and _cmt.workflow_metadata: + debug_stats['has_workflow_metadata'] += 1 + if 'reward_stats' in _cmt.workflow_metadata: + debug_stats['has_reward_stats'] += 1 + reward_stats_list.append(_cmt.workflow_metadata['reward_stats']) + + return reward_stats_list, debug_stats + + +def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: + """ + Compute SwanLab metrics from reward_stats list. + + Supports two data sources: + 1. RM Gallery RewardStats fields (rm_raw, etc.) + 2. OpenJudge fields (openjudge_xxx_raw, openjudge_xxx_contribution, etc.) 
+ + Args: + reward_stats_list: List of reward_stats dictionaries + prefix: Metric name prefix (e.g., "val/" for validation phase) + + Returns: + Formatted metrics dictionary ready for SwanLab reporting + """ + if not reward_stats_list: + return {} + + n = len(reward_stats_list) + metrics = {} + + # ========== Top-level Scores (General) ========== + final_reward_list = [rs.get('final_reward', 0.0) for rs in reward_stats_list] + fused_reward_list = [rs.get('fused_reward', 0.0) for rs in reward_stats_list] + penalty_list = [rs.get('penalty', 0.0) for rs in reward_stats_list] + step_reward_list = [rs.get('step_reward', 0.0) for rs in reward_stats_list] + + # Penalty statistics + non_zero_penalties = [p for p in penalty_list if p != 0.0] + + # Top-level metrics + metrics[f"{prefix}rewards/final_reward_mean"] = float(np.mean(final_reward_list)) + metrics[f"{prefix}rewards/fused_reward_mean"] = float(np.mean(fused_reward_list)) + metrics[f"{prefix}rewards/penalty_mean"] = float(np.mean(penalty_list)) + metrics[f"{prefix}rewards/step_reward_mean"] = float(np.mean(step_reward_list)) + metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties) + metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0 + + # ========== Detect OpenJudge Usage ========== + openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False)) + + if openjudge_enabled_count > 0: + # ========== OpenJudge Metrics ========== + metrics[f"{prefix}rewards/openjudge_enabled_rate"] = openjudge_enabled_count / n * 100 + + # Dynamically extract OpenJudge grader fields + # Currently supported graders: report_resolution, trajectory_faithfulness, + # rubrics_performance, trajectory_comprehensive, information_gain, action_loop + openjudge_graders = [ + "report_resolution", + "trajectory_faithfulness", + "rubrics_performance", + "trajectory_comprehensive", + "information_gain", + "action_loop", + ] + + for grader_name in openjudge_graders: + raw_key = f"openjudge_{grader_name}_raw" + contrib_key = f"openjudge_{grader_name}_contribution" + + raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list] + contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list] + + # Only report when non-zero values exist + if any(v != 0.0 for v in raw_list): + metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list)) + if any(v != 0.0 for v in contrib_list): + metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list)) + + # OpenJudge time consumption statistics + grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list] + if any(v != 0.0 for v in grading_time_list): + metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list)) + metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list)) + + # ========== RM Gallery Metrics ========== + + # RM Gallery + rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list] + rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list] + + # RefJudge + ref_final_raw_list = [rs.get('ref_final_raw', 0.0) for rs in reward_stats_list] + ref_citation_raw_list = [rs.get('ref_citation_raw', 0.0) for rs in reward_stats_list] + ref_grounding_raw_list = [rs.get('ref_grounding_raw', 0.0) for rs in reward_stats_list] + ref_contribution_list = [rs.get('ref_contribution', 0.0) for rs in reward_stats_list] + + # StructureJudge + structure_raw_list = 
[rs.get('structure_raw', 0.0) for rs in reward_stats_list] + structure_contribution_list = [rs.get('structure_contribution', 0.0) for rs in reward_stats_list] + + # dimensions/ raw scores + metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_final_raw_mean"] = float(np.mean(ref_final_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_citation_raw_mean"] = float(np.mean(ref_citation_raw_list)) + metrics[f"{prefix}rewards/dimensions/ref_grounding_raw_mean"] = float(np.mean(ref_grounding_raw_list)) + metrics[f"{prefix}rewards/dimensions/structure_raw_mean"] = float(np.mean(structure_raw_list)) + + # contribution/ weighted contributions + metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list)) + metrics[f"{prefix}rewards/contribution/ref_contribution_mean"] = float(np.mean(ref_contribution_list)) + metrics[f"{prefix}rewards/contribution/structure_contribution_mean"] = float(np.mean(structure_contribution_list)) + + # Enabled state statistics + ref_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('ref_judge_enabled', False)) + if ref_judge_enabled_count > 0: + metrics[f"{prefix}rewards/ref_judge_enabled_rate"] = ref_judge_enabled_count / n * 100 + + structure_judge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('structure_judge_enabled', False)) + if structure_judge_enabled_count > 0: + metrics[f"{prefix}rewards/structure_judge_enabled_rate"] = structure_judge_enabled_count / n * 100 + + # Time consumption statistics + rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list] + refstruc_time_list = [rs.get('refstruc_time', 0.0) for rs in reward_stats_list] + + metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list)) + metrics[f"{prefix}judge_time/refstruc_time_mean"] = float(np.mean(refstruc_time_list)) + + if rm_time_list: + metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list)) + if refstruc_time_list: + metrics[f"{prefix}judge_time/refstruc_time_max"] = float(np.max(refstruc_time_list)) + + # ========== General Time Consumption Statistics ========== + judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list] + if any(v != 0.0 for v in judge_total_time_list): + metrics[f"{prefix}judge_time/judge_total_time_mean"] = float(np.mean(judge_total_time_list)) + metrics[f"{prefix}judge_time/judge_total_time_max"] = float(np.max(judge_total_time_list)) + + return metrics + + +def compute_reward_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: + """ + Training phase: Extract reward_stats from trajectories and compute metrics. + + Args: + trajectories: List of trajectory objects + + Returns: + Formatted metrics dictionary + """ + reward_stats_list = extract_reward_stats_from_trajectories(trajectories) + return compute_reward_metrics(reward_stats_list, prefix="train_") + + +def compute_reward_metrics_from_cmts(cmts: List[Any], print_debug: bool = True) -> Dict[str, float]: + """ + Validation phase: Extract reward_stats from cmts and compute metrics. 
+ + Args: + cmts: List of cmt objects + print_debug: Whether to print debug information + + Returns: + Formatted metrics dictionary (with "val_" prefix) + """ + reward_stats_list, debug_stats = extract_reward_stats_from_cmts(cmts) + + if print_debug: + print(f"\n[DEBUG eval_dataset()] reward_stats statistics:") + print(f" - Total cmts count: {debug_stats['total_cmts']}") + print(f" - Has workflow_metadata: {debug_stats['has_workflow_metadata']}") + print(f" - Has reward_stats: {debug_stats['has_reward_stats']}") + print(f" - Extracted samples count: {len(reward_stats_list)}") + + return compute_reward_metrics(reward_stats_list, prefix="val_") diff --git a/ajet/utils/metric_helper/save_trajectory_as_json.py b/ajet/utils/metric_helper/save_trajectory_as_json.py new file mode 100644 index 00000000..0e380abc --- /dev/null +++ b/ajet/utils/metric_helper/save_trajectory_as_json.py @@ -0,0 +1,56 @@ +import os +import json +from ajet.utils.msg_converter import convert_grouped_steps_to_openai_format + + +def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): + """ + Save ctx_trackers to JSON files for either training or evaluation. + + Args: + ctx_trackers (list): List of context trackers containing trajectory data. + global_steps (int): The global step count to organize saved files. + prefix (str): Directory prefix indicating the type of trajectory ("train" or "eval"). + """ + for ctx_tracker in ctx_trackers: + # Determine task tag based on reward + reward = ctx_tracker.reward_structure.raw_reward + if reward >= 1: + ctx_tracker.tag = "success" + elif reward == 0: + ctx_tracker.tag = "failure" + else: + ctx_tracker.tag = "half_success" + + formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.timeline_cache) + + # Prepare trajectory data + traj_data = { + "task_id": ctx_tracker.task_id, + "task_tag": ctx_tracker.tag, + "reward_structure": ctx_tracker.reward_structure.model_dump(), + "traj": formatted_traj + } + + # Extract reward_stats from workflow_metadata if available + if hasattr(ctx_tracker, 'workflow_metadata') and ctx_tracker.workflow_metadata: + if 'reward_stats' in ctx_tracker.workflow_metadata: + traj_data['reward_structure']['reward_stats'] = ctx_tracker.workflow_metadata['reward_stats'] + + # Define save directory and file path + traj_save_dir = os.path.join( + os.environ.get("BEST_LOGGER_PATH", "launcher_record"), + "ctx_trackers", + prefix, + f"step_{global_steps}" + ) + os.makedirs(traj_save_dir, exist_ok=True) + traj_file_path = os.path.join(traj_save_dir, f"{ctx_tracker.task_id}.json") + + # Save trajectory data to JSON file + with open(traj_file_path, "w", encoding="utf-8") as f: + json.dump(traj_data, f, ensure_ascii=False, indent=2) + + # Print confirmation for evaluation trajectories + if prefix != "train": + print(f"Saved trajectory to {traj_file_path}") \ No newline at end of file diff --git a/ajet/utils/metric_helper/tool_metric_helper.py b/ajet/utils/metric_helper/tool_metric_helper.py new file mode 100644 index 00000000..e9c7728d --- /dev/null +++ b/ajet/utils/metric_helper/tool_metric_helper.py @@ -0,0 +1,167 @@ +""" +FinWorld Tool Metrics Helper + +Specialized module for extracting tool-related statistics and formatting SwanLab reports. +Extracts data from workflow_metadata['tool_stats']. + +SwanLab metrics directory structure: +- tool_stats/ Overall statistics (success rate, cache hit rate, etc.) 
+- tool_time/ Time consumption statistics by tool +- tool_cache/ Cache hit rate by tool +- tool_error/ Error rate by tool +""" + +from typing import List, Dict, Any +import numpy as np + + +def extract_tool_stats_from_trajectories(trajectories: List[Any]) -> List[Dict[str, Any]]: + """ + Extract tool_stats from trajectories list. + + Args: + trajectories: List of trajectory objects containing workflow_metadata + + Returns: + List of tool_stats dictionaries + """ + tool_stats_list = [] + for traj in trajectories: + if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: + if 'tool_stats' in traj.workflow_metadata: + tool_stats_list.append(traj.workflow_metadata['tool_stats']) + return tool_stats_list + + +def extract_tool_stats_from_cmts(cmts: List[Any]) -> List[Dict[str, Any]]: + """ + Extract tool_stats from cmts list. + + Args: + cmts: List of cmt objects containing workflow_metadata + + Returns: + List of tool_stats dictionaries + """ + tool_stats_list = [] + for traj in cmts: + if hasattr(traj, 'workflow_metadata') and traj.workflow_metadata: + if 'tool_stats' in traj.workflow_metadata: + tool_stats_list.append(traj.workflow_metadata['tool_stats']) + return tool_stats_list + + +def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "") -> Dict[str, float]: + """ + Compute SwanLab metrics from tool_stats list. + + Args: + tool_stats_list: List of tool_stats dictionaries + prefix: Metric name prefix (e.g., "val/" for validation phase) + + Returns: + Formatted metrics dictionary ready for SwanLab reporting + """ + if not tool_stats_list: + return {} + + metrics = {} + + # ========== 1. Overall Statistics ========== + total_calls_list = [stats.get('total_calls', 0) for stats in tool_stats_list] + success_calls_list = [stats.get('success_calls', 0) for stats in tool_stats_list] + error_calls_list = [stats.get('total_errors', 0) for stats in tool_stats_list] + cache_hits_list = [stats.get('cache_hits', 0) for stats in tool_stats_list] + cache_misses_list = [stats.get('cache_misses', 0) for stats in tool_stats_list] + + # Calculate overall success rate + total_calls_sum = sum(total_calls_list) + success_calls_sum = sum(success_calls_list) + tool_success_rate = (success_calls_sum / total_calls_sum * 100) if total_calls_sum > 0 else 0.0 + + # Calculate overall cache hit rate + cache_total = sum(cache_hits_list) + sum(cache_misses_list) + cache_hit_rate = (sum(cache_hits_list) / cache_total * 100) if cache_total > 0 else 0.0 + + metrics.update({ + f"{prefix}tool_stats/tool_success_rate": tool_success_rate, + f"{prefix}tool_stats/tool_total_calls": float(np.mean(total_calls_list)), + f"{prefix}tool_stats/tool_success_calls": float(np.mean(success_calls_list)), + f"{prefix}tool_stats/tool_error_calls": float(np.mean(error_calls_list)), + f"{prefix}tool_stats/tool_cache_hit_rate": cache_hit_rate, + f"{prefix}tool_stats/tool_cache_hits": float(np.mean(cache_hits_list)), + f"{prefix}tool_stats/tool_cache_misses": float(np.mean(cache_misses_list)), + }) + + # ========== 2. 
Time Consumption Statistics by Tool ========== + tool_time_by_name = {} + for stats in tool_stats_list: + tool_time_dict = stats.get('tool_time', {}) + for tool_name, time_list in tool_time_dict.items(): + if tool_name not in tool_time_by_name: + tool_time_by_name[tool_name] = [] + if isinstance(time_list, list): + tool_time_by_name[tool_name].extend(time_list) + + for tool_name, time_list in tool_time_by_name.items(): + if time_list: + metrics[f"{prefix}tool_time/{tool_name}/mean"] = float(np.mean(time_list)) + metrics[f"{prefix}tool_time/{tool_name}/max"] = float(np.max(time_list)) + metrics[f"{prefix}tool_time/{tool_name}/count"] = len(time_list) + + # ========== 3. Cache Hit Rate by Tool ========== + tool_cache_by_name = {} + for stats in tool_stats_list: + tool_cache_stats = stats.get('tool_cache_stats', {}) + for tool_name, cache_info in tool_cache_stats.items(): + if tool_name not in tool_cache_by_name: + tool_cache_by_name[tool_name] = {'hits': 0, 'misses': 0} + tool_cache_by_name[tool_name]['hits'] += cache_info.get('hits', 0) + tool_cache_by_name[tool_name]['misses'] += cache_info.get('misses', 0) + + for tool_name, cache_info in tool_cache_by_name.items(): + hits = cache_info['hits'] + misses = cache_info['misses'] + total = hits + misses + if total > 0: + hit_rate = hits / total * 100 + metrics[f"{prefix}tool_cache/{tool_name}/hit_rate"] = round(hit_rate, 2) + metrics[f"{prefix}tool_cache/{tool_name}/hits"] = hits + metrics[f"{prefix}tool_cache/{tool_name}/misses"] = misses + + # ========== 4. Error Rate by Tool ========== + tool_error_by_name = {} + for stats in tool_stats_list: + tool_error_stats = stats.get('tool_error_stats', {}) + for tool_name, error_info in tool_error_stats.items(): + if tool_name not in tool_error_by_name: + tool_error_by_name[tool_name] = {'calls': 0, 'errors': 0} + tool_error_by_name[tool_name]['calls'] += error_info.get('calls', 0) + tool_error_by_name[tool_name]['errors'] += error_info.get('errors', 0) + + for tool_name, error_info in tool_error_by_name.items(): + calls = error_info['calls'] + errors = error_info['errors'] + if calls > 0: + error_rate = errors / calls * 100 + metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) + metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls + metrics[f"{prefix}tool_error/{tool_name}/errors"] = errors + + return metrics + + +def compute_tool_metrics_from_trajectories(trajectories: List[Any]) -> Dict[str, float]: + """ + Training phase: Extract tool_stats from trajectories and compute metrics. + """ + tool_stats_list = extract_tool_stats_from_trajectories(trajectories) + return compute_tool_metrics(tool_stats_list, prefix="train_") + + +def compute_tool_metrics_from_cmts(cmts: List[Any]) -> Dict[str, float]: + """ + Validation phase: Extract tool_stats from cmts and compute metrics. + """ + tool_stats_list = extract_tool_stats_from_cmts(cmts) + return compute_tool_metrics(tool_stats_list, prefix="val_") diff --git a/ajet/utils/msg_converter.py b/ajet/utils/msg_converter.py new file mode 100644 index 00000000..0437f5ca --- /dev/null +++ b/ajet/utils/msg_converter.py @@ -0,0 +1,104 @@ +""" +Message format conversion utilities + +Provides bidirectional conversion between OpenAI format and AgentScope format. +Unified for both train and val phases. 
+ +## OpenAI format examples: +- Assistant with tool_calls: + {"role": "assistant", "content": "...", "tool_calls": [{"id": "call_xxx", "type": "function", "function": {"name": "...", "arguments": "..."}}]} +- Tool result: + {"role": "tool", "content": "...", "tool_call_id": "call_xxx"} +- Normal message: + {"role": "user/assistant/system", "content": "..."} + +## AgentScope format examples: +- Assistant with tool_calls: + {"role": "assistant", "content": [{"type": "text", "text": "..."}, {"type": "tool_use", "id": "call_xxx", "name": "...", "input": {...}}]} +- Tool result: + {"role": "user", "content": [{"type": "tool_result", "id": "call_xxx", "output": "..."}]} +- Normal message: + {"role": "user/assistant/system", "content": "..."} +""" + +import json +from typing import List, Dict, Any, Union + + + +# ============================================================================= +# ExtendedMessage -> OpenAI conversion (backward compatible functions) +# ============================================================================= + +def convert_ext_msg_to_openai_format(ext_msg: Any) -> Dict[str, Any]: + """ + Convert a single ExtendedMessage or dict to OpenAI format message. + + Args: + ext_msg: ExtendedMessage object or dict + + Returns: + Message dict in OpenAI format + """ + # Helper function: get attribute value + def get_attr(obj, attr_name, default=None): + if hasattr(obj, attr_name): + return getattr(obj, attr_name) + elif isinstance(obj, dict): + return obj.get(attr_name, default) + return default + + # Check if there are tool_calls (assistant initiates tool call) + tool_calls = get_attr(ext_msg, 'tool_calls') + has_tool_calls = bool(tool_calls) + + # Check if there's tool_call_id (tool return result) + tool_call_id = get_attr(ext_msg, 'tool_call_id') + has_tool_call_id = bool(tool_call_id) + + # Get basic attributes + role = get_attr(ext_msg, 'role', 'user') + content = get_attr(ext_msg, 'content', '') + + if has_tool_calls: + # Assistant message contains tool_calls -> keep OpenAI format + msg_dict = { + "role": "assistant", + "content": content if content else "", + "tool_calls": tool_calls + } + elif has_tool_call_id: + # Tool return result -> use OpenAI format (role: "tool") + msg_dict = { + "role": "tool", + "content": content if content else "", + "tool_call_id": tool_call_id + } + else: + # Normal message, keep original format + msg_dict = { + "role": role, + "content": content if content else "" + } + + return msg_dict + + +def convert_grouped_steps_to_openai_format(timelines: List[List[Any]]) -> List[List[Dict[str, Any]]]: + """ + Convert timelines (multi-turn conversation steps) to OpenAI format. 
+ + Args: + timelines: List of List of ExtendedMessage or dict + + Returns: + Trajectory data in OpenAI format (List of List of dict) + """ + formatted_traj = [] + for context in timelines: + step_msgs = [] + for ext_msg in context: + msg_dict = convert_ext_msg_to_openai_format(ext_msg) + step_msgs.append(msg_dict) + formatted_traj.append(step_msgs) + return formatted_traj diff --git a/ajet/utils/pty.py b/ajet/utils/pty.py index a2124775..6d859ae1 100644 --- a/ajet/utils/pty.py +++ b/ajet/utils/pty.py @@ -40,7 +40,8 @@ def master_read(fd): # log_f.write(data.decode()) # log_f.flush() # Also print to stdout (optional) - print(data.decode(), end="") + # Use errors='replace' to handle incomplete UTF-8 sequences + print(data.decode(errors='replace'), end="") return data # Define stdin read callback @@ -85,7 +86,7 @@ def base64_to_string(b): def pty_wrapper( cmd: list[str], dir: str, - env_dict: dict = {}, + env_dict: dict[str, str] = {}, ): run_command_with_pty(cmd, working_dir=dir, env_dict=env_dict) @@ -109,7 +110,7 @@ def pty_launch(service_name: str, success_std_string="Starting server on"): use_pty=True, ) companion.launch( - launch_wait_time=1800, + launch_wait_time=3600, success_std_string=success_std_string, ) diff --git a/tests/bench/README.md b/tests/bench/README.md index b8f0a096..a849d3ca 100644 --- a/tests/bench/README.md +++ b/tests/bench/README.md @@ -14,13 +14,14 @@ Note: `tests/bench` source code is for test robot only, therefore `yaml` configu source .venv/bin/activate python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py +python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py +python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py -python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py -python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py -python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py::TestBenchmarkCountdown::test_01_begin_verl -python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl -python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py::TestBenchmarkAppworld::test_01_begin_verl -python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_02_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_math/execute_benchmark_math.py::TestBenchmarkMath::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_appworld/execute_benchmark_appworld.py::TestBenchmarkAppworld::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_countdown/execute_benchmark_countdown.py::TestBenchmarkCountdown::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py::TestBenchmarkLearnToAsk::test_01_begin_verl +VERL_PYTHON="./.venv/bin/python" python -m pytest -s tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py::TestBenchmarkFrozenLake::test_01_begin_verl ``` diff --git a/tests/bench/benchmark_appworld/execute_benchmark_appworld.py b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py index df28a7be..eb37a189 100644 --- 
a/tests/bench/benchmark_appworld/execute_benchmark_appworld.py +++ b/tests/bench/benchmark_appworld/execute_benchmark_appworld.py @@ -8,6 +8,7 @@ class TestBenchmarkAppworld(BenchmarkTestCase): + def test_01_begin_verl(self): # get probe target, so as to get timeout settings BACKBONE = "verl" @@ -16,7 +17,7 @@ def test_01_begin_verl(self): # tests/bench/benchmark_appworld/benchmark_appworld.py # tests/bench/benchmark_appworld/benchmark_appworld.yaml TARGET_NAME = f"benchmark_appworld_{BACKBONE}" - PYTHON_EXECUTABLE = ".verl/bin/python" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") multi_nodes = False self.execute_benchmark( @@ -37,7 +38,7 @@ def test_02_begin_trinity(self): TEST_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml" PROBE_TARGET = "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" TARGET_NAME = f"benchmark_appworld_{BACKBONE}" - PYTHON_EXECUTABLE = ".venv/bin/python" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") multi_nodes = True self.execute_benchmark( diff --git a/tests/bench/benchmark_countdown/execute_benchmark_countdown.py b/tests/bench/benchmark_countdown/execute_benchmark_countdown.py index 78adbd4d..db74e25a 100644 --- a/tests/bench/benchmark_countdown/execute_benchmark_countdown.py +++ b/tests/bench/benchmark_countdown/execute_benchmark_countdown.py @@ -1,15 +1,20 @@ +import os import unittest from tests.bench.benchmark_base import BenchmarkTestCase + + class TestBenchmarkCountdown(BenchmarkTestCase, unittest.TestCase): + def test_01_begin_verl(self): BACKBONE = "verl" TEST_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.yaml" PROBE_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" TARGET_NAME = f"benchmark_countdown_{BACKBONE}" - PYTHON_EXECUTABLE = ".verl/bin/python" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, @@ -23,7 +28,8 @@ def test_02_begin_trinity(self): TEST_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.yaml" PROBE_TARGET = "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" TARGET_NAME = f"benchmark_countdown_{BACKBONE}" - PYTHON_EXECUTABLE = ".venv/bin/python" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, diff --git a/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py b/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py index fe78de4a..7de32908 100644 --- a/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py +++ b/tests/bench/benchmark_frozenlake/execute_benchmark_frozenlake.py @@ -1,15 +1,17 @@ -import unittest +import os from tests.bench.benchmark_base import BenchmarkTestCase class TestBenchmarkFrozenLake(BenchmarkTestCase): - def test_01_begin_trinity(self): - BACKBONE = "trinity" + + def test_01_begin_verl(self): + BACKBONE = "verl" TEST_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml" PROBE_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" TARGET_NAME = f"benchmark_frozenlake_{BACKBONE}" - PYTHON_EXECUTABLE = ".venv/bin/python" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, @@ -18,12 +20,13 @@ def test_01_begin_trinity(self): python_executable=PYTHON_EXECUTABLE, ) - def test_02_begin_verl(self): - BACKBONE = 
"verl" + def test_02_begin_trinity(self): + BACKBONE = "trinity" TEST_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.yaml" PROBE_TARGET = "tests/bench/benchmark_frozenlake/benchmark_frozenlake.py->TestProbe" TARGET_NAME = f"benchmark_frozenlake_{BACKBONE}" - PYTHON_EXECUTABLE = ".verl/bin/python" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, diff --git a/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py b/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py index 71bc0013..e467e030 100644 --- a/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py +++ b/tests/bench/benchmark_learn2ask/execute_benchmark_learn2ask.py @@ -1,17 +1,18 @@ -import unittest +import os from tests.bench.benchmark_base import BenchmarkTestCase class TestBenchmarkLearnToAsk(BenchmarkTestCase): - def test_01_begin_trinity(self): + + def test_01_begin_verl(self): # get probe target, so as to get timeout settings - BACKBONE = "trinity" + BACKBONE = "verl" TEST_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml" PROBE_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" TARGET_NAME = f"benchmark_learn2ask_{BACKBONE}" - # PYTHON_EXECUTABLE = "python" - PYTHON_EXECUTABLE = ".venv/bin/python" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, @@ -20,14 +21,14 @@ def test_01_begin_trinity(self): python_executable=PYTHON_EXECUTABLE, ) - def test_02_begin_verl(self): + def test_02_begin_trinity(self): # get probe target, so as to get timeout settings - BACKBONE = "verl" + BACKBONE = "trinity" TEST_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml" PROBE_TARGET = "tests/bench/benchmark_learn2ask/benchmark_learn2ask.py->TestProbe" TARGET_NAME = f"benchmark_learn2ask_{BACKBONE}" - # PYTHON_EXECUTABLE = "python" - PYTHON_EXECUTABLE = ".verl/bin/python" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, diff --git a/tests/bench/benchmark_math/execute_benchmark_math.py b/tests/bench/benchmark_math/execute_benchmark_math.py index 1feadd04..afd5e41e 100644 --- a/tests/bench/benchmark_math/execute_benchmark_math.py +++ b/tests/bench/benchmark_math/execute_benchmark_math.py @@ -1,14 +1,17 @@ +import os from tests.bench.benchmark_base import BenchmarkTestCase class TestBenchmarkMath(BenchmarkTestCase): - def test_02_begin_trinity(self): + + def test_01_begin_verl(self): # get probe target, so as to get timeout settings - BACKBONE = "trinity" + BACKBONE = "verl" TEST_TARGET = "tests/bench/benchmark_math/benchmark_math.yaml" PROBE_TARGET = "tests/bench/benchmark_math/benchmark_math.py->TestProbe" TARGET_NAME = f"benchmark_math_{BACKBONE}" - PYTHON_EXECUTABLE = ".venv/bin/python" + PYTHON_EXECUTABLE = os.environ.get("VERL_PYTHON", ".verl/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET, @@ -17,13 +20,14 @@ def test_02_begin_trinity(self): python_executable=PYTHON_EXECUTABLE, ) - def test_01_begin_verl(self): + def test_02_begin_trinity(self): # get probe target, so as to get timeout settings - BACKBONE = "verl" + BACKBONE = "trinity" TEST_TARGET = "tests/bench/benchmark_math/benchmark_math.yaml" PROBE_TARGET = "tests/bench/benchmark_math/benchmark_math.py->TestProbe" TARGET_NAME = f"benchmark_math_{BACKBONE}" - 
PYTHON_EXECUTABLE = ".verl/bin/python" + PYTHON_EXECUTABLE = os.environ.get("TRINITY_PYTHON", ".venv/bin/python") + self.execute_benchmark( backbone=BACKBONE, test_target=TEST_TARGET,