Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 50 additions & 61 deletions ajet/utils/metric_helper/reward_metric_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
deep_finance Reward Metrics Helper

Provides standalone utility functions for reward_stats extraction and SwanLab metrics formatting.
Decouples deep_finance-specific logic from core code, reducing intrusion into native_compat_trainer.

Data sources:
1. Finance Evaluator (finance_raw, finance_contribution)
2. OpenJudge Graders (openjudge_xxx_raw, openjudge_xxx_contribution)

SwanLab metrics directory structure:
- rewards/ Top-level aggregated scores
- rewards/dimensions/ Raw scores (unweighted)
- rewards/contribution/ Weighted contributions
- rewards/dimensions/ Raw scores (unweighted): finance_raw, openjudge_*_raw
- rewards/contribution/ Weighted contributions: finance_contribution, openjudge_*_contribution
- rewards/openjudge/ OpenJudge grader specific metrics
- judge_time/ Judge time consumption statistics
"""

Expand Down Expand Up @@ -41,9 +45,9 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str
"""
Compute SwanLab metrics from reward_stats list.

Supports two data sources:
1. RM Gallery RewardStats fields (rm_raw, etc.)
2. OpenJudge fields (openjudge_xxx_raw, openjudge_xxx_contribution, etc.)
Data sources:
1. Finance Evaluator (finance_raw, finance_contribution)
2. OpenJudge Graders (openjudge_xxx_raw, openjudge_xxx_contribution)

Args:
reward_stats_list: List of reward_stats dictionaries
Expand Down Expand Up @@ -72,61 +76,46 @@ def compute_reward_metrics(reward_stats_list: List[Dict[str, Any]], prefix: str
metrics[f"{prefix}rewards/fused_reward_mean"] = float(np.mean(fused_reward_list))
metrics[f"{prefix}rewards/penalty_mean"] = float(np.mean(penalty_list))
metrics[f"{prefix}rewards/step_reward_mean"] = float(np.mean(step_reward_list))
metrics[f"{prefix}rewards/penalty_count"] = len(non_zero_penalties)
metrics[f"{prefix}rewards/penalty_rate"] = len(non_zero_penalties) / n * 100 if n > 0 else 0.0

# ========== OpenJudge Metrics (PresentationQualityGrader, GroundingGrader) ==========
openjudge_enabled_count = sum(1 for rs in reward_stats_list if rs.get('openjudge_enabled', False))

if openjudge_enabled_count > 0:
# OpenJudge graders: presentation_quality, grounding
openjudge_graders = [
"presentation_quality",
"grounding",
"planning",
"audit",
"traceability",
"cgcv"
]

for grader_name in openjudge_graders:
raw_key = f"openjudge_{grader_name}_raw"
contrib_key = f"openjudge_{grader_name}_contribution"

raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list]
contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list]

# Only report when non-zero values exist
if any(v != 0.0 for v in raw_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list))
if any(v != 0.0 for v in contrib_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list))

# OpenJudge time consumption statistics
grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list]
if any(v != 0.0 for v in grading_time_list):
metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list))
metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list))

# ========== RM Gallery Metrics ==========

# RM Gallery
rm_raw_list = [rs.get('rm_raw', 0.0) for rs in reward_stats_list]
rm_contribution_list = [rs.get('rm_contribution', 0.0) for rs in reward_stats_list]

# dimensions/ raw scores
metrics[f"{prefix}rewards/dimensions/rm_raw_mean"] = float(np.mean(rm_raw_list))

# contribution/ weighted contributions
metrics[f"{prefix}rewards/contribution/rm_contribution_mean"] = float(np.mean(rm_contribution_list))


# Time consumption statistics
rm_time_list = [rs.get('rm_time', 0.0) for rs in reward_stats_list]
metrics[f"{prefix}judge_time/rm_time_mean"] = float(np.mean(rm_time_list))

if rm_time_list:
metrics[f"{prefix}judge_time/rm_time_max"] = float(np.max(rm_time_list))
metrics[f"{prefix}rewards/penalty_count"] = float(len(non_zero_penalties))
metrics[f"{prefix}rewards/penalty_rate"] = float(len(non_zero_penalties) / n * 100) if n > 0 else 0.0

# ========== OpenJudge Metrics ==========
# OpenJudge graders: presentation_quality, grounding, audit
openjudge_graders = [
"presentation_quality",
"grounding",
"planning",
"audit",
]

for grader_name in openjudge_graders:
raw_key = f"openjudge_{grader_name}_raw"
contrib_key = f"openjudge_{grader_name}_contribution"

raw_list = [rs.get(raw_key, 0.0) for rs in reward_stats_list]
contrib_list = [rs.get(contrib_key, 0.0) for rs in reward_stats_list]

# Only report when non-zero values exist
if any(v != 0.0 for v in raw_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_raw_mean"] = float(np.mean(raw_list))
if any(v != 0.0 for v in contrib_list):
metrics[f"{prefix}rewards/openjudge/{grader_name}_contribution_mean"] = float(np.mean(contrib_list))

# OpenJudge time consumption statistics
grading_time_list = [rs.get('grading_time', 0.0) for rs in reward_stats_list]
if any(v != 0.0 for v in grading_time_list):
metrics[f"{prefix}judge_time/openjudge_grading_time_mean"] = float(np.mean(grading_time_list))
metrics[f"{prefix}judge_time/openjudge_grading_time_max"] = float(np.max(grading_time_list))

# ========== Finance Evaluator Metrics ==========
finance_raw_list = [rs.get('finance_raw', 0.0) for rs in reward_stats_list]
finance_contribution_list = [rs.get('finance_contribution', 0.0) for rs in reward_stats_list]

if any(v != 0.0 for v in finance_raw_list):
metrics[f"{prefix}rewards/dimensions/finance_raw_mean"] = float(np.mean(finance_raw_list))

if any(v != 0.0 for v in finance_contribution_list):
metrics[f"{prefix}rewards/contribution/finance_contribution_mean"] = float(np.mean(finance_contribution_list))

# ========== General Time Consumption Statistics ==========
judge_total_time_list = [rs.get('judge_total_time', 0.0) for rs in reward_stats_list]
Expand Down
28 changes: 28 additions & 0 deletions tutorial/example_deep_finance/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# API keys
OPENAI_API_KEY="sk-xxx"
OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
RM_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
RM_API_KEY="sk-xxx"
OPENJUDGE_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENJUDGE_API_KEY="sk-xxx"
STRONG_MODEL_API_KEY="sk-xxx"

SWANLAB_API_KEY="xxx"

# data path, save path
ENV_SERVICE_ROOT="/path/to/env_service"
CONDA_PATH="/path/to/conda/conda.sh"
MODEL_PATH="/path/to/base_model"
CKPT_SAVE_PATH="/path/to/ckpt_path"
# Data file path configuration
TRAIN_DATA_PATH="/path/to/train_data"
VAL_DATA_PATH="/path/to/val_data"


TRAIN_REF_ANS_PATH="/path/to/train_reference_answer"
VAL_REF_ANS_PATH="/path/to/val_reference_answer"


# Port
ADDR=""
MCP_PORT=""
Loading
Loading