Skip to content

Commit 3cc6e58

Browse files
agent framework abl study
1 parent c68c154 commit 3cc6e58

12 files changed

Lines changed: 1210 additions & 18 deletions

File tree

ajet/backbone/verl/dp_actor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,14 @@ def update_policy(self, data: DataProto):
138138
# make sure we are in training mode
139139
self.actor_module.train()
140140

141+
# [AJET] Optional: estimate the GPU-memory limit of ppo_max_token_len_per_gpu and raise.
142+
# Triggered by env AGENTJET_FIND_MAX_PPO_TOKEN_LEN. Intercepts the *first* real PPO update
143+
# (model + grads + optimizer already resident, so the measurement is realistic). See
144+
# ajet.utils.find_max_ppo_token_len. It never returns.
145+
if os.environ.get("AGENTJET_FIND_MAX_PPO_TOKEN_LEN"):
146+
from ajet.utils.find_max_ppo_token_len import find_max_ppo_token_len_per_gpu
147+
find_max_ppo_token_len_per_gpu(self, data)
148+
141149
temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid silent error
142150
pad_token_id = data.meta_info.get("pad_token_id", 0)
143151

ajet/copilot/job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ def __init__(
127127
max_model_len: int | None = None,
128128
tensor_model_parallel_size: int | None = None,
129129
max_num_seqs: int | None = None,
130+
ppo_max_token_len_per_gpu: int | None = None,
131+
ulysses_sequence_parallel_size: int | None = None,
130132
mini_batch_num: int | None = None,
131133
lora_rank: int | None = None,
132134
lora_alpha: int | None = None,
@@ -195,6 +197,8 @@ def __init__(
195197
self.max_model_len: int = cast(int, max_model_len)
196198
self.tensor_model_parallel_size: int = cast(int, tensor_model_parallel_size)
197199
self.max_num_seqs: int = cast(int, max_num_seqs)
200+
self.ppo_max_token_len_per_gpu: int | None = ppo_max_token_len_per_gpu
201+
self.ulysses_sequence_parallel_size: int | None = ulysses_sequence_parallel_size
198202
self.mini_batch_num: int = cast(int, mini_batch_num)
199203
self.lora_rank: int = cast(int, lora_rank)
200204
self.lora_alpha: int = cast(int, lora_alpha)
@@ -236,6 +240,8 @@ def __init__(
236240
"ajet.rollout.max_model_len": "max_model_len",
237241
"ajet.rollout.tensor_model_parallel_size": "tensor_model_parallel_size",
238242
"ajet.rollout.max_num_seqs": "max_num_seqs",
243+
"ajet.rollout.ppo_max_token_len_per_gpu": "ppo_max_token_len_per_gpu",
244+
"ajet.trainer_common.ulysses_sequence_parallel_size": "ulysses_sequence_parallel_size",
239245
"ajet.trainer_common.mini_batch_num": "mini_batch_num",
240246
"ajet.lora.lora_rank": "lora_rank",
241247
"ajet.lora.lora_alpha": "lora_alpha",

ajet/default_config/ajet_config_schema.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
from dataclasses import dataclass, field
6-
from typing import Any, Dict, List
6+
from typing import Any, Dict, List, Optional
77

88

99
@dataclass
@@ -26,6 +26,7 @@ class AjetTrainerCommon:
2626
use_kl_in_reward: bool = False
2727
kl_penalty_type: str = "kl"
2828
ppo_epochs: int = 1
29+
ulysses_sequence_parallel_size: int = 1
2930
val_print_to_markdown_file_path: str | None = None
3031
train_print_to_markdown_file_path: str | None = None
3132
total_training_steps: int | None = None
@@ -34,20 +35,9 @@ class AjetTrainerCommon:
3435
total_epochs: int = 50
3536
val_pass_n: int = 4
3637
val_before_train: bool = False
37-
# When enabled, every sample produced by the same episode (same
38-
# non_tensor_batch["episode_uuids"]) gets its loss weight multiplied by
39-
# 1/N (N = number of samples in that episode) so each episode contributes
40-
# equally to the policy-gradient update regardless of how many samples it
41-
# generated. Disabled by default (current behaviour: every sample weighted
42-
# equally).
38+
# When enabled, every sample produced by the same episode (same non_tensor_batch["episode_uuids"]) gets its loss weight multiplied by 1/N (N = number of samples in that episode)
4339
loss_weight_normalization_episode_level: bool = False
44-
# When enabled, GRPO group statistics (baseline mean / std) are computed at
45-
# episode scope instead of sample scope: each episode (same
46-
# non_tensor_batch["episode_uuids"]) is first reduced to its mean reward,
47-
# then the per-task (same non_tensor_batch["uid"]) baseline is the mean over
48-
# those episode means. This makes every episode contribute equally to the
49-
# advantage baseline regardless of how many samples it generated. Disabled
50-
# by default (current behaviour: baseline is the mean over all samples).
40+
# When enabled, GRPO group statistics (baseline mean / std) are computed at episode scope instead of sample scope
5141
advantage_estimation_episode_level: bool = False
5242

5343
@dataclass
@@ -71,6 +61,8 @@ class AjetRollout:
7161
max_num_seqs: int = 64
7262
num_repeat: int = 8
7363
gpu_memory_utilization: float = 0.85
64+
# Per-GPU token budget for the PPO actor update (dynamic batching).
65+
ppo_max_token_len_per_gpu: Optional[int] = None # None => track ajet.rollout.max_model_len (legacy behaviour).
7466
compute_madness_checklist: List[str] = field(default_factory=list)
7567

7668

ajet/default_config/ajet_default.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ ajet:
4949
# max token length allowed for the model during rollout
5050
max_model_len: 18000
5151

52+
# Per-GPU token budget for the PPO actor update under dynamic batching.
53+
# null => follow max_model_len (the four verl *_token_len_per_gpu keys all track max_model_len).
54+
# <int> => override ONLY actor_rollout_ref.actor.ppo_max_token_len_per_gpu, so the PPO update
55+
# can pack larger micro-batches (squeeze more out of GPU memory) while keeping the
56+
# rollout/log_prob lengths tied to max_model_len.
57+
# Tip: set env AGENTJET_FIND_MAX_PPO_TOKEN_LEN=1 to auto-probe the hardware limit (the run will
58+
# raise with the discovered value, which you then paste here).
59+
ppo_max_token_len_per_gpu: null
60+
5261
multi_turn:
5362
# how many samples should be collected for each task run
5463
max_sample_per_task: 30

ajet/default_config/verl/config_auto_convertion_verl.jsonc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
"ajet.rollout.max_model_len": [
5050
"actor_rollout_ref.rollout.max_model_len",
5151
"actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu",
52-
"actor_rollout_ref.actor.ppo_max_token_len_per_gpu",
5352
"actor_rollout_ref.ref.log_prob_max_token_len_per_gpu"
5453
],
5554

ajet/utils/config_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ def recursive_copy(src_dict, dst_dict, parent_key=""):
166166
f"[Note]: Aligned parameter from [{from_key}] to [{to_key}] with value: [{value}]"
167167
)
168168

169+
# [AJET] ajet.rollout.ppo_max_token_len_per_gpu is the SOLE source of the actor's per-gpu PPO
170+
# token budget (it is deliberately NOT pulled in via the ajet.rollout.max_model_len mapping
171+
# above). None means "track max_model_len" (legacy behaviour); an explicit int decouples the
172+
# PPO update budget from max_model_len. The verl key only carries the resolved value so it can
173+
# reach the Ray actor (which receives the actor_rollout_ref subtree, not the ajet namespace).
174+
ppo_token_len = _dive_to_fetch_value(from_config, "ajet.rollout.ppo_max_token_len_per_gpu")
175+
if ppo_token_len is None:
176+
ppo_token_len = _dive_to_fetch_value(from_config, "ajet.rollout.max_model_len")
177+
_dive_to_set_value(to_config, "actor_rollout_ref.actor.ppo_max_token_len_per_gpu", ppo_token_len)
178+
logger.success(
179+
f"[Note]: Resolved [actor_rollout_ref.actor.ppo_max_token_len_per_gpu] = [{ppo_token_len}] "
180+
"from [ajet.rollout.ppo_max_token_len_per_gpu] (None => ajet.rollout.max_model_len)."
181+
)
182+
169183
# backbone specific safe guard
170184
to_config = align_parameter_safe_guard(to_config, backbone)
171185

ajet/utils/core_env_vars.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict:
6464
"OPENAI_BASE_URL",
6565
"API_KEY",
6666
"BASE_URL",
67+
"AGENTJET_FIND_MAX_PPO_TOKEN_LEN",
68+
"AGENTJET_FIND_MAX_START",
69+
"AGENTJET_FIND_MAX_CAP",
70+
"AGENTJET_FIND_MAX_TOL",
71+
"AGENTJET_FIND_MAX_BUDGET_S",
72+
"AGENTJET_FIND_MAX_SEQ",
6773
]
6874

6975
for var in optional_env_vars:

0 commit comments

Comments
 (0)