Skip to content

Commit f2cf697

Browse files
committed
feat: add val_only option and markdown output for validation metrics
- Add val_only config to run validation without training
- Add val_print_to_markdown_file_path to output validation metrics to file
- Rename TGC metrics to task_pass_rate and add task_pass_rate@k for k in [2, 4, 8, 16]
- Add std_reward metric
- Add assertion to prevent incompatible swarm_mode with val_before_train/val_only
1 parent 2c8e4c1 commit f2cf697

File tree

6 files changed

+33
-9
lines changed

6 files changed

+33
-9
lines changed

ajet/backbone/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
"AjetTaskReader",
1414
]
1515
except ImportError:
16-
logger.info("trinity is not available.")
16+
pass
17+
# logger.info("trinity is not available.")

ajet/backbone/trainer_verl.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,14 +419,23 @@ def fit(self): # noqa: C901
419419
self.checkpoint_manager.update_weights(self.global_steps)
420420
self.checkpoint_manager.sleep_replicas()
421421

422+
# [oc] swarm_mode is not compatible with `val_before_train` and `val_only`
423+
assert not (self.config.ajet.enable_swarm_mode and (self.config.ajet.trainer_common.val_before_train or self.config.ajet.trainer_common.val_only)), \
424+
"swarm_mode is not compatible with `val_before_train` and `val_only`"
425+
426+
422427
# perform validation before training
423428
# currently, we only support validation using the reward_function.
424-
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode):
429+
if (self.val_reward_fn is not None) and (self.config.ajet.trainer_common.val_before_train) and (not self.config.ajet.enable_swarm_mode):
425430
val_metrics = self._validate()
426431
assert val_metrics, f"{val_metrics=}"
427-
pprint(f"Initial validation metrics: {val_metrics}")
428432
self.verl_logger.log(data=val_metrics, step=self.global_steps)
429-
if self.config.trainer.get("val_only", False):
433+
val_print_to_markdown_file_path = self.config.ajet.trainer_common.val_print_to_markdown_file_path
434+
if val_print_to_markdown_file_path:
435+
with open(val_print_to_markdown_file_path, mode="a+") as f:
436+
f.write(str(val_metrics))
437+
f.write('\n')
438+
if self.config.ajet.trainer_common.val_only:
430439
return
431440

432441
# add tqdm
@@ -983,11 +992,21 @@ def _rollout_val_dataset(self, target_dataset, target_dataset_name, mode, epoch)
983992
"total_tasks": len(task_results),
984993
"num_all_success_tasks": num_all_success_tasks,
985994
f"num_pass_n_tasks(pass@{pass_n})": num_pass_n_tasks,
986-
"TGC@1": repeated_success_tasks / (num_tasks * pass_n),
987-
f"TGC@{pass_n}": num_pass_n_tasks / num_tasks,
988-
f"TGC@{pass_n}-all-pass": num_all_success_tasks / num_tasks,
995+
# [oc]: change var name TGC -> task_pass_rate
996+
"task_pass_rate@1": repeated_success_tasks / (num_tasks * pass_n),
997+
f"task_pass_rate@{pass_n}": num_pass_n_tasks / num_tasks,
998+
f"task_pass_rate@{pass_n}-all-pass": num_all_success_tasks / num_tasks,
989999
"mean_reward": sum(rewards) / len(rewards) if rewards else 0,
1000+
"std_reward": np.std(rewards) if rewards else 0,
9901001
}
1002+
for k in [2, 4, 8, 16]:
1003+
if pass_n > k:
1004+
num_pass_k = 0
1005+
for task_id, task_outcomes in task_results.items():
1006+
if any(tag == "success" for tag in task_outcomes["tag_arr"][:k]):
1007+
num_pass_k += 1
1008+
val_metrics[f"task_pass_rate@{k}"] = num_pass_k / num_tasks
1009+
9911010
save_trajectory_as_json_file(ctx_trackers, self.global_steps, self.config, prefix="eval")
9921011
update_metrics(ctx_trackers, val_metrics, prefix="eval_")
9931012
print_dict(

ajet/default_config/ajet_default.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ ajet:
234234
# validation before training
235235
val_before_train: False
236236
val_pass_n: 4
237+
val_only: False
238+
val_print_to_markdown_file_path: null
237239

238240
# save and test frequency (in step)
239241
save_freq: 20

ajet/default_config/verl/config_auto_convertion_verl.jsonc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"ajet.trainer_common.total_epochs": "trainer.total_epochs",
33

44
"ajet.trainer_common.val_before_train": "trainer.val_before_train",
5+
"ajet.trainer_common.val_only": "trainer.val_only",
56
"ajet.trainer_common.n_gpus_per_node": "trainer.n_gpus_per_node",
67
"ajet.trainer_common.nnodes": "trainer.nnodes",
78
"ajet.trainer_common.logger": "trainer.logger",

tests/bench/benchmark_math/benchmark_math.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,16 @@ ajet:
4343
debug_first_n_tasks: 1
4444

4545
trainer_common:
46-
save_freq: 100
46+
save_freq: 999999
4747
test_freq: 999999
4848
total_epochs: 100
4949
logger: swanlab
50+
val_print_to_markdown_file_path: './qwen2-7b.md'
5051
nnodes: 1
5152
n_gpus_per_node: 4
5253
# loss = loss * loss_extra_scale_ratio
5354
loss_extra_scale_ratio: 1.0
55+
val_before_train: true
5456

5557

5658
execute_test: True # DO NOT EDIT, THIS IS FOR TEST ROBOT

tutorial/opencode_build_openclaw_interactive_train/fake_vllm_endpoint.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,6 @@ async def finalize_episodes(task: Task, valid_results: List[EpisodeResult], rewa
299299
lambda ep=episode_result, wo=workflow_output: swarm_client.end_episode(task, ep.episode_uuid, wo),
300300
)
301301

302-
# [oc]: 微调一下代码,当 handle_one2many_request_run_episodes 运行完时,随机stream回去一个答案,但stream不中断,等待reward计算完之后,再结束stream
303302
async def handle_one2many_request(request: Request, request_id: str) -> Dict | List[bytes]:
304303
task, valid_results, all_answers, user_query, all_answers = await handle_one2many_request_run_episodes(request, request_id)
305304
best_answer = await handle_one2many_request_run_reward(task, valid_results, all_answers, user_query)

0 commit comments

Comments (0)