
Commit 5977031

revise commit
1 parent 836e380 commit 5977031

13 files changed: 192 additions & 235 deletions


ajet/backbone/trainer_verl.py

Lines changed: 5 additions & 31 deletions
@@ -17,8 +17,6 @@
 from collections import defaultdict
 from pprint import pprint
 from typing import List, Optional
-from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_cmts
-from loguru import logger as loguru_logger
 
 import hydra
 import numpy as np
@@ -56,14 +54,7 @@
 from ajet.schema.task import Task
 from ajet.task_reader import dict_to_ajet_task
 from ajet.task_rollout.native_parallel_worker import VerlRolloutManager
-from ajet.utils.save_trajectory import save_train_trajectory, save_eval_trajectory
-from ajet.utils.msg_converter import (
-    convert_grouped_steps_to_openai_format,
-    convert_ext_msg_to_openai_format,
-    agentscope_to_openai,
-    openai_to_agentscope,
-)
-from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_cmts
+from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics
 
 def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | torch.Tensor:
     """
@@ -586,9 +577,6 @@ def fit(self):  # noqa: C901
                 tasks, mode="sample", epoch=f"train.{epoch}"
             )
             logger.info("=" * 10 + "end fit rollout" + "=" * 10)
-
-            if self.config.ajet.trainer_common.save_trajectory:
-                save_train_trajectory(context_tracker_arr, self.global_steps)
             logger.info("begin to convert context_tracker_arr to dataproto")
             gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
             logger.info("end conversion")
@@ -614,14 +602,8 @@ def fit(self):  # noqa: C901
                     ),
                 }
             )
-            from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories
-            from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories
-            tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr)
-            reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr)
-            if tool_metrics:
-                metrics.update(tool_metrics)
-            if reward_metrics:
-                metrics.update(reward_metrics)
+            save_trajectory_as_json_file(context_tracker_arr, self.global_steps, self.config, prefix="train")
+            update_metrics(context_tracker_arr, metrics)
             if self.config.ajet.execute_test:  # apply a test probe
                 from swanlab.data.run.main import get_run
 
@@ -1049,10 +1031,6 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
         for ctx_tracker in ctx_trackers:
             ctx_tracker.generate_log()
 
-        # save eval trajectories
-        if self.config.ajet.trainer_common.save_trajectory:
-            save_eval_trajectory(ctx_trackers, self.global_steps)
-
         rewards = [ctx_tracker.reward_structure.raw_reward for ctx_tracker in ctx_trackers]
         num_tasks = len(task_results)
         assert num_tasks == len(ctx_trackers) // pass_n
@@ -1068,12 +1046,8 @@ def eval_dataset(self, target_dataset, target_dataset_name, mode, epoch):
             f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks,
             "mean_reward": sum(rewards) / len(rewards) if rewards else 0,
         }
-        reward_metrics = compute_reward_metrics_from_cmts(ctx_trackers, print_debug=True)
-        tool_metrics = compute_tool_metrics_from_cmts(ctx_trackers)
-        if tool_metrics:
-            val_metrics.update(reward_metrics)
-        if reward_metrics:
-            val_metrics.update(tool_metrics)
+        save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval")
+        update_metrics(ctx_trackers, val_metrics)
         print_dict(
             val_metrics,
             narrow=True,
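
Note: a worked toy example (invented numbers, not from the repo) of the two metrics this hunk keeps. With pass_n rollouts per task, a task counts toward TGC@{pass_n}-all-pass only if every one of its rollouts succeeds:

pass_n = 4
rewards = [
    1.0, 1.0, 0.0, 1.0,  # task A: 3/4 rollouts succeed, not all-pass
    1.0, 1.0, 1.0, 1.0,  # task B: 4/4 rollouts succeed, all-pass
]
num_tasks = len(rewards) // pass_n  # 2
num_all_success_tasks = 1           # task B only
val_metrics = {
    f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks,   # 0.5
    "mean_reward": sum(rewards) / len(rewards) if rewards else 0,  # 0.875
}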

ajet/context_tracker/base_tracker.py

Lines changed: 10 additions & 5 deletions
@@ -1,5 +1,6 @@
 from typing import List, Tuple, Union
 from typing import List, Union, Tuple, Dict, Optional, Any
+from ajet.schema.task import WorkflowTask
 
 from ajet.schema.extended_msg import (
     INVALID_LOG_PROB_VALUE,
@@ -111,10 +112,14 @@ def replace_token_ids(
 
 
 class BaseTracker(object):
-    def __init__(self, config, tokenizer, **kwargs):
-        self.task_batch_index = kwargs.get("task_batch_index", "undefined")
-        self.task_tag = kwargs.get("task_tag", "undefined")
-        self.task_id = kwargs.get("task_id", "undefined")
+    def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
+
+        self.workflow_task = workflow_task
+        self.task_batch_index = self.workflow_task.task_batch_index
+        self.task_tag = self.workflow_task.task_tag
+        self.task_id = self.workflow_task.task_id
+        self.episode_uuid = self.workflow_task.episode_uuid
+
         self.config = config
         self.tokenizer = tokenizer
         self.saved_timelines: List[List[ExtendedMessage]] = []
@@ -136,7 +141,7 @@ def __init__(self, config, tokenizer, **kwargs):
         self.already_mad_flag: bool = False
         self.round_cnt = 0
         self.generation_prompt_token = None
-        self.workflow_metadata: Optional[Dict[str, Any]] = None  # Initialize workflow_metadata to store tool statistics
+        self.log_metrics: Optional[Dict[str, Union[float, List[float]]]] = None  # store per-episode log metrics (e.g. tool statistics)
 
         assert (
             self.config.ajet.data.max_prompt_length
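
Note: with the new signature, subclass call sites migrate roughly as in this hedged sketch (MyTracker and the identifier values are hypothetical, not from the commit):

# Old style: identifiers passed as loose kwargs.
#   tracker = MyTracker(config, tokenizer,
#                       task_batch_index=0, task_tag="demo", task_id="t-01")
# New style: the WorkflowTask object carries the identifiers.
tracker = MyTracker(config, tokenizer, workflow_task=workflow_task)
assert tracker.task_id == workflow_task.task_id
assert tracker.episode_uuid == workflow_task.episode_uuid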

ajet/context_tracker/multiagent_tracking.py

Lines changed: 0 additions & 2 deletions
@@ -50,7 +50,6 @@ def __init__(
         config,
         should_interrupt_fn,
         generated_token_callback_fn,
-        episode_uuid: str,
         **kwargs,
     ):
         super().__init__(config, tokenizer, **kwargs)
@@ -61,7 +60,6 @@
         self.output_kwargs = {}
         self.input_kwargs = {}
         self.timeline_cache = {}
-        self.episode_uuid = episode_uuid
 
 
     def preprocess_tools_field(self, tools: List[dict] = [], disable_toolcalls: bool = False):

ajet/default_config/ajet_default.yaml

Lines changed: 28 additions & 1 deletion
@@ -231,30 +231,57 @@ ajet:
 
   # trainer common configurations
   trainer_common:
+
+    # validation before training
     val_before_train: False
     val_pass_n: 4
+
+    # save and test frequency (in steps)
     save_freq: 20
     test_freq: 20
-    save_trajectory: False  # whether to save train/eval trajectories to JSON files
+
+    # total training epochs
     total_epochs: 50
+
     nnodes: 1
     n_gpus_per_node: 8
+
+    # logger selection
     logger: swanlab
+
+    # algorithm setting
     algorithm:
       adv_estimator: grpo
       use_kl_in_reward: False
+
+    # number of optimizer.step calls per big batch
     mini_batch_num: 1
+
+    # verl offload configs
     fsdp_config:
       param_offload: True
       optimizer_offload: True
+
+    # learning rate
     optim:
       lr: 1e-6
+
+    # enable KL loss regularization
     use_kl_loss: True
+
+    # KL divergence loss coefficient
     kl_loss_coef: 0.002
     kl_loss_type: low_var_kl
+
+    # Ulysses-specific configs
     ulysses_sequence_parallel_size: 1
+
+    # base directory to save checkpoints
     checkpoint_base_dir: ./saved_checkpoints
 
+    # whether to save train/eval trajectories to JSON files
+    save_trajectory_as_json_file: False
+
 
 
 
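
Note: a minimal sketch of flipping the new flag programmatically, assuming the OmegaConf stack that the Hydra-based trainer already uses and the default config path shown above:

from omegaconf import OmegaConf

cfg = OmegaConf.load("ajet/default_config/ajet_default.yaml")
OmegaConf.update(cfg, "ajet.trainer_common.save_trajectory_as_json_file", True)
assert cfg.ajet.trainer_common.save_trajectory_as_json_file is True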
ajet/launcher.py

Lines changed: 0 additions & 1 deletion
@@ -16,7 +16,6 @@
 from ajet.utils.pty import pty_launch
 
 set_loguru_default_color()
-# load_dotenv()
 load_dotenv(override=False)
 
 
ajet/schema/task.py

Lines changed: 1 addition & 0 deletions
@@ -43,3 +43,4 @@ class WorkflowOutput(BaseModel):
     reward: Union[float, List[float], None] = Field(default=None)
     is_success: Union[bool, None] = Field(default=None)
     metadata: Dict[str, Any] = Field(default_factory=dict)
+    log_metrics: Dict[str, Union[float, List[float]]] = Field(default_factory=dict)
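
Note: a short sketch of how a user workflow might populate the new field; the metric names below are invented for illustration. Values may be single floats or lists of floats:

from ajet.schema.task import WorkflowOutput

output = WorkflowOutput(
    reward=1.0,
    is_success=True,
    log_metrics={
        "num_tool_calls": 3.0,          # hypothetical scalar metric
        "per_step_reward": [0.2, 0.8],  # hypothetical per-step list
    },
)
# GeneralRunner.execute() later copies this onto the tracker:
#   context_tracker.log_metrics = workflow_output.log_metrics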

ajet/task_rollout/resource_keeper.py

Lines changed: 5 additions & 4 deletions
@@ -178,22 +178,23 @@ def step(self, action: dict) -> Tuple[str, float, bool, dict]:
         )
         obs = ""
         assert isinstance(env_output, dict)
-        # === Support list-type state passthrough ===
-        # 1. If state is a list (new standard format), pass through directly
+
         if isinstance(env_output["state"], list):
+            # 1. If state is a list (new standard format), pass through directly
             obs = env_output["state"]
-        # 2. If state is a dict (old format or error)
         else:
+            # 2. If state is a dict (old format or error)
             if ("content" not in env_output["state"]) and ("error" in env_output["state"]):
                 obs = f"[Error from environment: {env_output['error']}]"
             elif env_output["state"].get("content", "") == "":
                 obs = "Warning: the environment does not provide any feedback, please provide valid input and try again."
             else:
                 obs = env_output["state"]["content"]
+
         reward = 0
         info = {}
         terminate = env_output["is_terminated"]
-        return obs, reward, terminate, info
+        return obs, reward, terminate, info  # type: ignore
 
     def reset(self) -> str:
         """Reset gym environment."""
ajet/task_runner/general_runner.py

Lines changed: 3 additions & 31 deletions
@@ -17,9 +17,6 @@ class GeneralRunner(BaseAgentRunner):
     def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
         observation_window = workflow_task.observation_window
         task_thread_index = workflow_task.task_thread_index
-        task_batch_index = workflow_task.task_batch_index
-        task_tag = workflow_task.task_tag
-        task_id = workflow_task.task_id
 
         workflow_import = self.config.ajet.rollout.user_workflow
         workflow_cls = dynamic_import(workflow_import)
@@ -34,10 +31,7 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
             llm_inference_fn=self.llm_inference_fn,
             tokenizer=self.tokenizer,
             config=self.config,
-            task_batch_index=task_batch_index,
-            task_tag=task_tag,
-            task_id=task_id,
-            episode_uuid=workflow_task.episode_uuid,
+            workflow_task=workflow_task,
             **hooks,
         )
         tuner = AjetTuner(
@@ -46,38 +40,17 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
             user_workflow=user_workflow,
             config=self.config,
         )
-
         workflow_output: WorkflowOutput = asyncio.run(
             user_workflow.execute(workflow_task, tuner)
         )
-        # set workflow metadata to context tracker metadata
-        context_tracker.workflow_metadata = workflow_output.metadata
         if workflow_output.reward is not None:
             raw_reward, is_success = (
                 workflow_output.reward,
                 workflow_output.is_success,
             )
         else:
             raw_reward, is_success = self.get_judge().compute_reward(workflow_task, workflow_output)
-
-        # ✅ Critical Fix: After calling `judge`, write the updated `reward_stats` back to `workflow_metadata`
-        # # Ensure that `native_compat_trainer` reads the actual value calculated by `judge`, not the 0 value returned by `env`.
-        if workflow_output.metadata and 'reward_stats' in workflow_output.metadata:
-            context_tracker.workflow_metadata['reward_stats'] = workflow_output.metadata['reward_stats']
-        else:
-            # fallback: If the judge does not update reward_stats, use the default value.
-            logger.warning(f"[WARN] reward_stats not found in metadata after judge call, creating default values")
-            default_reward_stats = {
-                'original_reward': raw_reward,
-                'penalty': 0.0,
-                'step_reward': raw_reward,
-            }
-            if workflow_output.metadata:
-                workflow_output.metadata['reward_stats'] = default_reward_stats
-                context_tracker.workflow_metadata['reward_stats'] = default_reward_stats
-            else:
-                context_tracker.workflow_metadata = {'reward_stats': default_reward_stats}
-
+
         workflow_task.gym_env = None  # clear gym env client reference to avoid serialization issue
 
         assert not isinstance(
@@ -95,12 +68,11 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
         )
         context_tracker.process_reward(reward)
         # generate token before merging
-        context_tracker.task_id = task_id
-        context_tracker.task_tag = task_tag
         context_tracker.group_merge()
         # after merging, process and align reward again
         context_tracker.process_reward(reward)
         # mark the thread as ended
         observation_window["step"][task_thread_index] = -1
         tuner.terminate_episode()
+        context_tracker.log_metrics = workflow_output.log_metrics
         return context_tracker
ajet/utils/metric_helper/__init__.py (new file)

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from ajet.utils.metric_helper.save_trajectory_as_json import save_trajectory_as_json
+from ajet.utils.metric_helper.tool_metric_helper import compute_tool_metrics_from_trajectories
+from ajet.utils.metric_helper.reward_metric_helper import compute_reward_metrics_from_trajectories
+
+
+def save_trajectory_as_json_file(ctx_trackers, global_steps, config, prefix):
+    if config.ajet.trainer_common.save_trajectory_as_json_file:
+        save_trajectory_as_json(ctx_trackers, global_steps, prefix)
+
+def update_metrics(context_tracker_arr, metrics: dict):
+    tool_metrics = compute_tool_metrics_from_trajectories(context_tracker_arr)
+    reward_metrics = compute_reward_metrics_from_trajectories(context_tracker_arr)
+    if tool_metrics:
+        metrics.update(tool_metrics)
+    if reward_metrics:
+        metrics.update(reward_metrics)
+    return
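
Note: a hedged usage sketch mirroring the trainer_verl.py call sites above; context_tracker_arr, global_steps, and config are assumed to come from the surrounding training loop and are not constructed here:

from ajet.utils.metric_helper import save_trajectory_as_json_file, update_metrics

metrics = {}  # per-step metrics dict that is later sent to the logger
update_metrics(context_tracker_arr, metrics)  # merges tool/reward metrics in place
save_trajectory_as_json_file(  # no-op unless the config flag is enabled
    context_tracker_arr, global_steps, config, prefix="train"
)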
ajet/utils/metric_helper/save_trajectory_as_json.py (new file)

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+import os
+import json
+from ajet.utils.msg_converter import convert_grouped_steps_to_openai_format
+
+
+def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"):
+    """
+    Save ctx_trackers to JSON files for either training or evaluation.
+
+    Args:
+        ctx_trackers (list): List of context trackers containing trajectory data.
+        global_steps (int): The global step count to organize saved files.
+        prefix (str): Directory prefix indicating the type of trajectory ("train" or "eval").
+    """
+    for ctx_tracker in ctx_trackers:
+        # Determine task tag based on reward
+        reward = ctx_tracker.reward_structure.raw_reward
+        if reward >= 1:
+            ctx_tracker.tag = "success"
+        elif reward == 0:
+            ctx_tracker.tag = "failure"
+        else:
+            ctx_tracker.tag = "half_success"
+
+        formatted_traj = convert_grouped_steps_to_openai_format(ctx_tracker.timeline_cache)
+
+        # Prepare trajectory data
+        traj_data = {
+            "task_id": ctx_tracker.task_id,
+            "task_tag": ctx_tracker.tag,
+            "reward_structure": ctx_tracker.reward_structure.model_dump(),
+            "traj": formatted_traj
+        }
+
+        # Extract reward_stats from workflow_metadata if available
+        if hasattr(ctx_tracker, 'workflow_metadata') and ctx_tracker.workflow_metadata:
+            if 'reward_stats' in ctx_tracker.workflow_metadata:
+                traj_data['reward_structure']['reward_stats'] = ctx_tracker.workflow_metadata['reward_stats']
+
+        # Define save directory and file path
+        traj_save_dir = os.path.join(
+            os.environ.get("BEST_LOGGER_PATH", "launcher_record"),
+            "ctx_trackers",
+            prefix,
+            f"step_{global_steps}"
+        )
+        os.makedirs(traj_save_dir, exist_ok=True)
+        traj_file_path = os.path.join(traj_save_dir, f"{ctx_tracker.task_id}.json")
+
+        # Save trajectory data to JSON file
+        with open(traj_file_path, "w", encoding="utf-8") as f:
+            json.dump(traj_data, f, ensure_ascii=False, indent=2)
+
+        # Print confirmation for evaluation trajectories
+        if prefix != "train":
+            print(f"Saved trajectory to {traj_file_path}")
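
Note: a pure path sketch of the resulting on-disk layout; the step number and task id are made up:

import os

save_dir = os.path.join(
    os.environ.get("BEST_LOGGER_PATH", "launcher_record"),
    "ctx_trackers", "eval", "step_120",
)
print(os.path.join(save_dir, "task_0042.json"))
# -> launcher_record/ctx_trackers/eval/step_120/task_0042.json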
