111 changes: 50 additions & 61 deletions ajet/utils/metric_helper/reward_metric_helper.py
@@ -2,12 +2,16 @@
deep_finance Reward Metrics Helper

Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting.
Decouples deep_finance-specific logic from the core code, keeping changes to native_compat_trainer to a minimum.

Data sources:
1. Finance Evaluator (finance_raw, finance_contribution)
2. OpenJudge Graders (openjudge_xxx_raw, openjudge_xxx_contribution)

SwanLab metrics directory structure:
- rewards/ Top-level aggregated scores
- rewards/dimensions/ Raw scores (unweighted)
- rewards/contribution/ Weighted contributions
- rewards/dimensions/ Raw scores (unweighted): finance_raw, openjudge_*_raw
- rewards/contribution/ Weighted contributions: finance_contribution, openjudge_*_contribution
- rewards/openjudge/ OpenJudge grader specific metrics
- judge_time/ Judge time consumption statistics
"""

@@ -41,9 +45,9 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str
"""
Compute SwanLab metrics from reward_stats list.

Supports two data sources:
1. RM Gallery RewardStats fields (rm_raw, etc.)
2. OpenJudge fields (openjudge_xxx_raw, openjudge_xxx_contribution, etc.)
Data sources:
1. Finance Evaluator (finance_raw, finance_contribution)
2. OpenJudge Graders (openjudge_xxx_raw, openjudge_xxx_contribution)

Args:
reward_stats_list: List of reward_stats dictionaries
@@ -72,61 +76,46 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str
metrics[f"{prefix}rewards/fused_reward_mean"] = float(np.mean(fused_reward_list))
metrics[f"{prefix}rewards/penalty_mean"] = float(np.mean(penalty_list))
metrics[f"{prefix}rewards/step_reward_mean"] = float(np.mean(step_reward_list))
metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties)
metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0

# ========== OpenJudge Metrics (PresentationQualityGrader, GroundingGrader) ==========
openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False))

if openjudge_enabled_count > 0:
# OpenJudge graders: presentation_quality, grounding
openjudge_graders = [
"presentation_quality",
"grounding",
"planning",
"audit",
"traceability",
"cgcv"
]

for grader_name in openjudge_graders:
raw_key = f"openjudge_{grader_name}_raw"
contrib_key = f"openjudge_{grader_name}_contribution"

raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list]
contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list]

# Only report when non-zero values exist
if any(v != 0.0 for v in raw_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list))
if any(v != 0.0 for v in contrib_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list))

# OpenJudge time consumption statistics
grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list]
if any(v != 0.0 for v in grading_time_list):
metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list))
metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list))

# ========== RM Gallery Metrics ==========

# RM Gallery
rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list]
rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list]

# dimensions/ raw scores
metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list))

# contribution/ weighted contributions
metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list))


# Time consumption statistics
rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list]
metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list))

if rm_time_list:
metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list))
metrics[f"{prefix}rewards/penalty_count"] = float(len(non_zero_penalties))
metrics[f"{prefix}rewards/penalty_rate"] = float(len(non_zero_penalties) / n * 100) if n > 0 else 0.0

# ========== OpenJudge Metrics ==========
# OpenJudge graders: presentation_quality, grounding, planning, audit
openjudge_graders = [
"presentation_quality",
"grounding",
"planning",
"audit",
]

for grader_name in openjudge_graders:
raw_key = f"openjudge_{grader_name}_raw"
contrib_key = f"openjudge_{grader_name}_contribution"

raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list]
contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list]

# Only report when non-zero values exist
if any(v != 0.0 for v in raw_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list))
if any(v != 0.0 for v in contrib_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list))

# OpenJudge time consumption statistics
grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list]
if any(v != 0.0 for v in grading_time_list):
metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list))
metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list))

# ========== Finance Evaluator Metrics ==========
finance_raw_list = [rs.get('finance_raw', 0.0) for rs in reward_stats_list]
finance_contribution_list = [rs.get('finance_contribution', 0.0) for rs in reward_stats_list]

if any(v != 0.0 for v in finance_raw_list):
metrics[f"{prefix}rewards/dimensions/finance_raw_mean"] = float(np.mean(finance_raw_list))

if any(v != 0.0 for v in finance_contribution_list):
metrics[f"{prefix}rewards/contribution/finance_contribution_mean"] = float(np.mean(finance_contribution_list))

# ========== General Time Consumption Statistics ==========
judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list]
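For context, a minimal sketch of how this helper might be consumed. The module path and function name come from the diff above; the call site, the prefix value, and the sample reward_stats values are hypothetical, and the sketch assumes the elided top of the function tolerates missing fields via .get(..., 0.0) defaults, as the visible code does:

from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics

# Two hypothetical per-sample reward_stats dicts, using field names from the diff above.
reward_stats_list = [
    {
        "finance_raw": 0.8, "finance_contribution": 0.4,
        "openjudge_grounding_raw": 0.9, "openjudge_grounding_contribution": 0.27,
        "grading_time": 1.3, "judge_total_time": 2.1,
    },
    {
        "finance_raw": 0.6, "finance_contribution": 0.3,
        "openjudge_grounding_raw": 0.7, "openjudge_grounding_contribution": 0.21,
        "grading_time": 0.9, "judge_total_time": 1.5,
    },
]

metrics = compute_reward_metrics(reward_stats_list, prefix="val/")

# Based on the visible logic, the result should contain keys such as:
#   val/rewards/dimensions/finance_raw_mean            -> 0.7
#   val/rewards/contribution/finance_contribution_mean -> 0.35
#   val/rewards/openjudge/grounding_raw_mean           -> 0.8
#   val/judge_time/openjudge_grading_time_mean         -> 1.1

Zero-valued dimensions are skipped entirely (the "only report when non-zero values exist" checks), so the metrics dict stays small when a grader is disabled.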
28 changes: 28 additions & 0 deletions tutorial/example_deep_finance/.env_example
@@ -0,0 +1,28 @@
# API keys
OPENAI_API_KEY="sk-xxx"
OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
RM_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
RM_API_KEY="sk-xxx"
OPENJUDGE_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENJUDGE_API_KEY="sk-xxx"
STRONG_MODEL_API_KEY="sk-xxx"

SWANLAB_API_KEY="xxx"

# Data and save paths
ENV_SERVICE_ROOT="/path/to/env_service"
CONDA_PATH="/path/to/conda/conda.sh"
MODEL_PATH="/path/to/base_model"
CKPT_SAVE_PATH="/path/to/ckpt_path"
# New: data file path configuration
TRAIN_DATA_PATH="/path/to/train_data"
VAL_DATA_PATH="/path/to/val_data"


TRAIN_REF_ANS_PATH="/path/to/train_reference_answer"
VAL_REF_ANS_PATH="/path/to/val_reference_answer"


# Address and port
ADDR=""
MCP_PORT=""