1717from collections import defaultdict
1818from pprint import pprint
1919from typing import List , Optional
20- from ajet .utils .metric_helper .reward_metric_helper import compute_reward_metrics_from_cmts
21- from loguru import logger as loguru_logger
2220
2321import hydra
2422import numpy as np
5654from ajet .schema .task import Task
5755from ajet .task_reader import dict_to_ajet_task
5856from ajet .task_rollout .native_parallel_worker import VerlRolloutManager
59- from ajet .utils .save_trajectory import save_train_trajectory , save_eval_trajectory
60- from ajet .utils .msg_converter import (
61- convert_grouped_steps_to_openai_format ,
62- convert_ext_msg_to_openai_format ,
63- agentscope_to_openai ,
64- openai_to_agentscope ,
65- )
66- from ajet .utils .metric_helper .tool_metric_helper import compute_tool_metrics_from_cmts
57+ from ajet .utils .metric_helper import save_trajectory_as_json_file , update_metrics
6758
6859def parse_reward_from_dataproto (data : DataProto , return_dict = False ) -> dict | torch .Tensor :
6960 """
@@ -586,9 +577,6 @@ def fit(self): # noqa: C901
586577 tasks , mode = "sample" , epoch = f"train.{ epoch } "
587578 )
588579 logger .info ("=" * 10 + "end fit rollout" + "=" * 10 )
589-
590- if self .config .ajet .trainer_common .save_trajectory :
591- save_train_trajectory (context_tracker_arr , self .global_steps )
592580 logger .info ("begin to convert context_tracker_arr to dataproto" )
593581 gen_batch_output = self .parallel_env .to_dataproto (context_tracker_arr )
594582 logger .info ("end convertion" )
@@ -614,14 +602,8 @@ def fit(self): # noqa: C901
614602 ),
615603 }
616604 )
617- from ajet .utils .metric_helper .tool_metric_helper import compute_tool_metrics_from_trajectories
618- from ajet .utils .metric_helper .reward_metric_helper import compute_reward_metrics_from_trajectories
619- tool_metrics = compute_tool_metrics_from_trajectories (context_tracker_arr )
620- reward_metrics = compute_reward_metrics_from_trajectories (context_tracker_arr )
621- if tool_metrics :
622- metrics .update (tool_metrics )
623- if reward_metrics :
624- metrics .update (reward_metrics )
605+ save_trajectory_as_json_file (context_tracker_arr , self .global_steps , self .config , prefix = "train" )
606+ update_metrics (context_tracker_arr , metrics )
625607 if self .config .ajet .execute_test : # apply a test probe
626608 from swanlab .data .run .main import get_run
627609
@@ -1049,10 +1031,6 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
10491031 for ctx_tracker in ctx_trackers :
10501032 ctx_tracker .generate_log ()
10511033
1052- # save eval trajectories
1053- if self .config .ajet .trainer_common .save_trajectory :
1054- save_eval_trajectory (ctx_trackers , self .global_steps )
1055-
10561034 rewards = [ctx_tracker .reward_structure .raw_reward for ctx_tracker in ctx_trackers ]
10571035 num_tasks = len (task_results )
10581036 assert num_tasks == len (ctx_trackers ) // pass_n
@@ -1068,12 +1046,8 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
10681046 f"TGC@{ pass_n } -all-pass" : num_all_success_tasks / num_tasks ,
10691047 "mean_reward" : sum (rewards ) / len (rewards ) if rewards else 0 ,
10701048 }
1071- reward_metrics = compute_reward_metrics_from_cmts (ctx_trackers , print_debug = True )
1072- tool_metrics = compute_tool_metrics_from_cmts (ctx_trackers )
1073- if tool_metrics :
1074- val_metrics .update (reward_metrics )
1075- if reward_metrics :
1076- val_metrics .update (tool_metrics )
1049+ save_trajectory_as_json_file (ctx_trackers , self .global_steps , self .config , prefix = "eval" )
1050+ update_metrics (ctx_trackers , val_metrics )
10771051 print_dict (
10781052 val_metrics ,
10791053 narrow = True ,
0 commit comments