Skip to content

Commit 7eb9a70

Browse files
committed
feat: enhance configuration for LoRA support and validation checks
1 parent 09f00fb commit 7eb9a70

File tree

6 files changed

+73
-20
lines changed

6 files changed

+73
-20
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,3 +177,4 @@ tutorial/**/*.json
177177
node_modules
178178
.agents
179179
skills-lock.json
180+
blueprint*

ajet/default_config/ajet_default.yaml

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -276,13 +276,11 @@ ajet:
276276
betas:
277277
- 0.9
278278
- 0.999
279-
clip_grad: 1.0
280279
min_lr_ratio: 0.0
281280
num_cycles: 0.5
282281
lr_scheduler_type: constant
283282
zero_indexed_step: true
284-
warmup_style: null
285-
override_optimizer_config: null
283+
grad_clip: 20.0
286284

287285
# enable KL loss regularization
288286
use_kl_loss: True
@@ -303,6 +301,13 @@ ajet:
303301
# whether to save train/eval trajectories to JSON files
304302
save_trajectory_as_json_file: False
305303

304+
lora:
305+
# LoRA configuration (disabled by default, set lora_rank > 0 to enable)
306+
lora_rank: 0
307+
lora_alpha: 16
308+
target_modules: all-linear
309+
load_format: auto
310+
306311

307312
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
308313
enable_swarm_mode: False

ajet/default_config/verl/config_auto_convertion_verl.jsonc

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,30 @@
1111
"ajet.trainer_common.algorithm.use_kl_in_reward": "algorithm.use_kl_in_reward",
1212
"ajet.trainer_common.mini_batch_num": "actor_rollout_ref.actor.override_ppo_mini_batch_num",
1313
"ajet.trainer_common.fsdp_config": "actor_rollout_ref.actor.fsdp_config",
14-
"ajet.trainer_common.optim": "actor_rollout_ref.actor.optim",
14+
"ajet.trainer_common.optim.optimizer": "actor_rollout_ref.actor.optim.optimizer",
15+
"ajet.trainer_common.optim.optimizer_impl": "actor_rollout_ref.actor.optim.optimizer_impl",
16+
"ajet.trainer_common.optim.lr": "actor_rollout_ref.actor.optim.lr",
17+
"ajet.trainer_common.optim.lr_warmup_steps_ratio": "actor_rollout_ref.actor.optim.lr_warmup_steps_ratio",
18+
"ajet.trainer_common.optim.total_training_steps": "actor_rollout_ref.actor.optim.total_training_steps",
19+
"ajet.trainer_common.optim.weight_decay": "actor_rollout_ref.actor.optim.weight_decay",
20+
"ajet.trainer_common.optim.lr_warmup_steps": "actor_rollout_ref.actor.optim.lr_warmup_steps",
21+
"ajet.trainer_common.optim.betas": "actor_rollout_ref.actor.optim.betas",
22+
"ajet.trainer_common.optim.min_lr_ratio": "actor_rollout_ref.actor.optim.min_lr_ratio",
23+
"ajet.trainer_common.optim.num_cycles": "actor_rollout_ref.actor.optim.num_cycles",
24+
"ajet.trainer_common.optim.lr_scheduler_type": "actor_rollout_ref.actor.optim.lr_scheduler_type",
25+
"ajet.trainer_common.optim.zero_indexed_step": "actor_rollout_ref.actor.optim.zero_indexed_step",
26+
"ajet.trainer_common.optim.grad_clip": "actor_rollout_ref.actor.optim.grad_clip",
1527
"ajet.trainer_common.use_kl_loss": "actor_rollout_ref.actor.use_kl_loss",
1628
"ajet.trainer_common.kl_loss_coef": "actor_rollout_ref.actor.kl_loss_coef",
1729
"ajet.trainer_common.kl_loss_type": "actor_rollout_ref.actor.kl_loss_type",
1830
"ajet.trainer_common.ulysses_sequence_parallel_size": "actor_rollout_ref.actor.ulysses_sequence_parallel_size",
1931
"ajet.trainer_common.loss_extra_scale_ratio": "actor_rollout_ref.actor.loss_extra_scale_ratio",
2032

33+
"ajet.lora.lora_rank": "actor_rollout_ref.model.lora_rank",
34+
"ajet.lora.lora_alpha": "actor_rollout_ref.model.lora_alpha",
35+
"ajet.lora.target_modules": "actor_rollout_ref.model.target_modules",
36+
"ajet.lora.load_format": "actor_rollout_ref.rollout.load_format",
37+
2138
"ajet.trainer_common.save_freq": "trainer.save_freq",
2239
"ajet.trainer_common.test_freq": "trainer.test_freq",
2340

ajet/default_config/verl/verl_default.yaml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -488,7 +488,6 @@ critic:
488488
betas:
489489
- 0.9
490490
- 0.999
491-
clip_grad: 1.0
492491
min_lr_ratio: 0.0
493492
num_cycles: 0.5
494493
lr_scheduler_type: constant

ajet/utils/config_utils.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -278,6 +278,41 @@ def expand_ajet_hierarchical_config(config, write_to=None):
278278
return config_final
279279

280280

281+
def _validate_input_yaml_no_overlap_with_auto_convertion_config(input_yaml_config, config_final):
282+
"""Validate that input yaml doesn't contain keys that will be auto-converted with different values."""
283+
import json
284+
import re
285+
286+
jsonc_path = os.path.join(os.path.dirname(__file__), "..", "default_config", "verl", "config_auto_convertion_verl.jsonc")
287+
with open(jsonc_path, "r", encoding="utf-8") as f:
288+
content = f.read()
289+
content = re.sub(r'//.*', '', content)
290+
convertion_json = json.loads(content)
291+
292+
errors = []
293+
for from_key, to_keys in convertion_json.items():
294+
to_keys = to_keys if isinstance(to_keys, list) else [to_keys]
295+
for to_key in to_keys:
296+
try:
297+
input_value = _dive_to_fetch_value(input_yaml_config, to_key)
298+
except ValueError:
299+
continue
300+
final_value = _dive_to_fetch_value(config_final, to_key)
301+
if str(input_value) != str(final_value):
302+
errors.append(
303+
f" - Key '{to_key}': input_yaml value = {input_value}, "
304+
f"but ajet config sets it to = {final_value}"
305+
)
306+
307+
if errors:
308+
error_msg = (
309+
"We found a configuration conflict between AgentJet and Verl! Input yaml contains keys that conflict with ajet default config values:\n"
310+
+ "\n".join(errors)
311+
+ "\nPlease use ajet.xxx to assign training parameters instead."
312+
)
313+
raise ValueError(error_msg)
314+
315+
281316
def prepare_experiment_config(yaml_path, exp_base_dir, backbone, override_param_callback=None, storage=True):
282317
"""
283318
Prepare experiment configuration by reading YAML, setting up backup directories,
@@ -299,7 +334,7 @@ def prepare_experiment_config(yaml_path, exp_base_dir, backbone, override_param_
299334

300335
## 0. read yaml & get experiment_name
301336
with open(yaml_path, "r", encoding="utf-8") as file:
302-
config = yaml.safe_load(file)
337+
config = input_yaml_config = yaml.safe_load(file)
303338
try:
304339
exp_name = config.get("ajet").get("experiment_name")
305340
except Exception:
@@ -367,6 +402,8 @@ def prepare_experiment_config(yaml_path, exp_base_dir, backbone, override_param_
367402
)
368403
config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst)
369404

405+
_validate_input_yaml_no_overlap_with_auto_convertion_config(input_yaml_config, config_final)
406+
370407
if not storage:
371408
shutil.rmtree(os.path.join(exp_base_dir, exp_name))
372409

tutorial/example_math_lora/math_agent.yaml

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# ------------------ main configuration ------------------
22
ajet:
3-
project_name: example_math_agent
3+
project_name: test_lora
44
task_reader:
55
type: huggingface_dat_repo # ✨✨✨✨ `env_service` or `dataset_file` or `huggingface_dat_repo`
66
# effective when `type: huggingface_dat_repo`
@@ -44,35 +44,29 @@ ajet:
4444
max_prompt_length: 3000
4545
max_response_length: 7000
4646

47+
execute_test: false
48+
4749
debug:
4850
debug_max_parallel: 1
4951
debug_first_n_tasks: 1
5052

5153
trainer_common:
54+
val_print_to_markdown_file_path: /mnt/data_cpfs/qingxu.fu/autoresearch-rl/exp_result/hello-agentjet-math-lora/val_result.md
55+
train_print_to_markdown_file_path: /mnt/data_cpfs/qingxu.fu/autoresearch-rl/exp_result/hello-agentjet-math-lora/train_result.md
5256
save_freq: 100
5357
test_freq: 100
5458
total_epochs: 100
5559
logger: swanlab
5660
val_before_train: true
61+
optim:
62+
lr: 3e-05
5763

58-
actor_rollout_ref:
59-
model:
64+
lora:
6065
lora_rank: 32
6166
lora_alpha: 32
6267
target_modules: all-linear
63-
actor:
64-
optim:
65-
lr: 3e-5
66-
fsdp_config:
67-
param_offload: true
68-
optimizer_offload: true
69-
rollout:
7068
load_format: safetensors
7169

72-
trinity:
73-
synchronizer:
74-
sync_offset: 1
75-
sync_method: nccl
7670

7771

7872
# ------------------ do not modify ------------------

0 commit comments

Comments (0)