Skip to content

Commit f52d83e

Browse files
committed
feat: add trajectory saving feature with config control
- Add save_trajectory.py module with save_train_trajectory and save_eval_trajectory functions - Add save_trajectory config flag in ajet_default.yaml (default: False) - Integrate trajectory saving in trainer_verl.py for both training and evaluation - Extract and save reward_structure, workflow_metadata, and OpenAI-formatted trajectories
1 parent cccc13f commit f52d83e

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from ajet.schema.task import Task
5555
from ajet.task_reader import dict_to_ajet_task
5656
from ajet.task_rollout.native_parallel_worker import VerlRolloutManager
57+
from ajet.utils.save_trajectory import save_train_trajectory, save_eval_trajectory
5758

5859

5960
def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor:
@@ -577,6 +578,9 @@ def fit(self): # noqa: C901
577578
tasks, mode="sample", epoch=f"train.{epoch}"
578579
)
579580
logger.info("=" * 10 + "end fit rollout" + "=" * 10)
581+
582+
if self.config.ajet.trainer_common.save_trajectory:
583+
save_train_trajectory(context_tracker_arr, self.global_steps)
580584
logger.info("begin to convert context_tracker_arr to dataproto")
581585
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
582586
logger.info("end convertion")
@@ -1029,6 +1033,10 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
10291033
for ctx_tracker in ctx_trackers:
10301034
ctx_tracker.generate_log()
10311035

1036+
# save eval trajectories
1037+
if self.config.ajet.trainer_common.save_trajectory:
1038+
save_eval_trajectory(ctx_trackers, self.global_steps)
1039+
10321040
rewards = [ctx_tracker.reward_structure.raw_reward for ctx_tracker in ctx_trackers]
10331041
num_tasks = len(task_results)
10341042
assert num_tasks == len(ctx_trackers) // pass_n

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ ajet:
235235
val_pass_n: 4
236236
save_freq: 20
237237
test_freq: 20
238+
save_trajectory: False # whether to save train/eval trajectories to JSON files
238239
total_epochs: 50
239240
nnodes: 1
240241
n_gpus_per_node: 8

ajet/utils/save_trajectory.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import os
2+
import json
3+
from ajet.utils.msg_converter import convert_grouped_steps_to_openai_format
4+
5+
6+
def _trajectory_save_dir(split, global_steps):
    """Build (and create if missing) the directory trajectories are saved under.

    Layout:
        <BEST_LOGGER_PATH or "launcher_record">/ctx_trackers/<split>/step_<global_steps>

    Args:
        split: sub-directory name, "train" or "val".
        global_steps: current trainer step, embedded in the directory name.

    Returns:
        The created directory path.
    """
    save_dir = os.path.join(
        os.environ.get("BEST_LOGGER_PATH", "launcher_record"),
        "ctx_trackers",
        split,
        f"step_{global_steps}",
    )
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _dump_tracker(ctx_tracker, save_dir):
    """Serialize one ctx_tracker to <save_dir>/<task_id>.json and return the path.

    The JSON payload holds the task id/tag, the dumped reward structure
    (augmented with ``reward_stats`` from ``workflow_metadata`` when present),
    and the trajectory converted to OpenAI chat format.

    NOTE(review): the filename is derived solely from ``task_id`` — if several
    trackers in one step share a task_id (e.g. pass_n > 1 rollouts), later
    dumps overwrite earlier ones. Confirm task_ids are unique per step.
    """
    # Prefer the tracker's own converter when it provides one; otherwise fall
    # back to the shared helper operating on the raw grouped_steps.
    if hasattr(ctx_tracker, "get_grouped_steps_openai_format"):
        formatted_traj = ctx_tracker.get_grouped_steps_openai_format()
    else:
        formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.grouped_steps)

    traj_data = {
        "task_id": ctx_tracker.task_id,
        "task_tag": ctx_tracker.tag,
        "reward_structure": ctx_tracker.reward_structure.model_dump(),
        "traj": formatted_traj,
    }
    # Surface per-rollout reward statistics next to the reward structure.
    if getattr(ctx_tracker, "workflow_metadata", None):
        if "reward_stats" in ctx_tracker.workflow_metadata:
            traj_data["reward_structure"]["reward_stats"] = ctx_tracker.workflow_metadata["reward_stats"]

    traj_file_path = os.path.join(save_dir, f"{ctx_tracker.task_id}.json")
    with open(traj_file_path, "w", encoding="utf-8") as f:
        json.dump(traj_data, f, ensure_ascii=False, indent=2)
    return traj_file_path


def save_train_trajectory(ctx_trackers, global_steps):
    """Tag each training ctx_tracker by reward and save it to a JSON file.

    Tags: reward >= 1 -> "success"; reward == 0 -> "failure"; anything else
    -> "half_success".

    Args:
        ctx_trackers: iterable of context trackers produced by a rollout.
        global_steps: current trainer step, used in the output directory name.
    """
    # Directory path and makedirs are loop-invariant — create once up front.
    save_dir = _trajectory_save_dir("train", global_steps)
    for ctx_tracker in ctx_trackers:
        reward = ctx_tracker.reward_structure.raw_reward
        if reward >= 1:
            ctx_tracker.tag = "success"
        elif reward == 0:
            ctx_tracker.tag = "failure"
        else:
            # Presumably intended for partial rewards 0 < r < 1; note that
            # negative rewards also land here — TODO confirm that is desired.
            ctx_tracker.tag = "half_success"
        _dump_tracker(ctx_tracker, save_dir)


def save_eval_trajectory(ctx_trackers, global_steps):
    """Save evaluation ctx_trackers to JSON files under the "val" split.

    Unlike save_train_trajectory, the tracker's existing ``tag`` is kept
    as-is. Each saved path is echoed to stdout.

    Args:
        ctx_trackers: iterable of context trackers produced by evaluation.
        global_steps: current trainer step, used in the output directory name.
    """
    save_dir = _trajectory_save_dir("val", global_steps)
    for ctx_tracker in ctx_trackers:
        traj_file_path = _dump_tracker(ctx_tracker, save_dir)
        print(f"Saved trajectory to {traj_file_path}")

0 commit comments

Comments
 (0)