Skip to content

Commit 1bb02c8

Browse files
committed
Add layered_summon parameter to AgentJetJob and related configurations
1 parent 7967f61 commit 1bb02c8

7 files changed

Lines changed: 72 additions & 23 deletions

File tree

ajet/copilot/job.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class AgentJetJob:
6464
lora_alpha: LoRA alpha scaling factor (default 16).
6565
lora_target_modules: Target modules for LoRA adaptation (default 'all-linear').
6666
lora_load_format: Load format for LoRA weights (default 'auto').
67+
layered_summon: Enable layered summon for LoRA (default False).
6768
gpu_memory_utilization: GPU memory utilization for vLLM engine (default 0.85).
6869
lr: Learning rate for optimizer (default 1e-6).
6970
"""
@@ -93,6 +94,7 @@ def __init__(
9394
lora_alpha: int | None = None,
9495
lora_target_modules: str | None = None,
9596
lora_load_format: str | None = None,
97+
layered_summon: bool | None = None,
9698
gpu_memory_utilization: float | None = None,
9799
lr: float | None = None,
98100
) -> None:
@@ -136,6 +138,7 @@ def __init__(
136138
self.lora_alpha: int = cast(int, lora_alpha)
137139
self.lora_target_modules: str = cast(str, lora_target_modules)
138140
self.lora_load_format: str = cast(str, lora_load_format)
141+
self.layered_summon: bool = cast(bool, layered_summon)
139142
self.gpu_memory_utilization: float = cast(float, gpu_memory_utilization)
140143
self.lr: float = cast(float, lr)
141144

@@ -164,6 +167,7 @@ def __init__(
164167
"ajet.lora.lora_alpha": "lora_alpha",
165168
"ajet.lora.target_modules": "lora_target_modules",
166169
"ajet.lora.load_format": "lora_load_format",
170+
"ajet.lora.layered_summon": "layered_summon",
167171
"ajet.rollout.gpu_memory_utilization": "gpu_memory_utilization",
168172
"ajet.trainer_common.optim.lr": "lr",
169173
}
@@ -194,6 +198,8 @@ def __init__(
194198
if self.lora_rank > 0:
195199
if self.lora_load_format != "safetensors":
196200
raise ValueError(f"When lora_rank > 0, lora_load_format must be 'safetensors', got '{self.lora_load_format}'")
201+
if not self.layered_summon:
202+
raise ValueError("When lora_rank > 0, layered_summon must be True")
197203
if self.lr is None:
198204
raise ValueError("lr should be provided for lora training")
199205
if self.lr <= 1e-5:

ajet/default_config/ajet_config_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class AjetLora:
4646
lora_alpha: int = 16
4747
target_modules: str = "all-linear"
4848
load_format: str = "auto"
49+
layered_summon: bool = False
4950

5051

5152
@dataclass

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ ajet:
316316
lora_alpha: 16
317317
target_modules: all-linear
318318
load_format: auto
319+
layered_summon: false
319320

320321

321322
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature

ajet/default_config/verl/config_auto_convertion_verl.jsonc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"ajet.lora.lora_alpha": "actor_rollout_ref.model.lora_alpha",
3636
"ajet.lora.target_modules": "actor_rollout_ref.model.target_modules",
3737
"ajet.lora.load_format": "actor_rollout_ref.rollout.load_format",
38+
"ajet.lora.layered_summon": "actor_rollout_ref.rollout.layered_summon",
3839

3940
"ajet.trainer_common.total_training_steps": "trainer.total_training_steps",
4041
"ajet.trainer_common.save_freq": "trainer.save_freq",

ajet/utils/config_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,21 @@ def align_parameter_safe_guard(config: dict, backbone: str) -> dict:
183183
if backbone == "verl" and isinstance(config["trainer"]["logger"], str):
184184
config["trainer"]["logger"] = ["console", config["trainer"]["logger"]]
185185

186-
# special: LoRA requires safetensors load_format
186+
# special: LoRA requires safetensors load_format and layered_summon
187187
if backbone == "verl":
188188
lora_rank = config.get("actor_rollout_ref", {}).get("model", {}).get("lora_rank", 0)
189189
load_format = config.get("actor_rollout_ref", {}).get("rollout", {}).get("load_format", "auto")
190+
layered_summon = config.get("actor_rollout_ref", {}).get("rollout", {}).get("layered_summon", False)
190191
if lora_rank > 0 and load_format != "safetensors":
191192
raise ValueError(
192193
f"LoRA training (lora_rank={lora_rank}) requires load_format='safetensors', "
193194
f"but got load_format='{load_format}'. Please set `ajet.lora.load_format: safetensors` in your config."
194195
)
196+
if lora_rank > 0 and not layered_summon:
197+
raise ValueError(
198+
f"LoRA training (lora_rank={lora_rank}) requires layered_summon=True, "
199+
f"but got layered_summon={layered_summon}. Please set `ajet.lora.layered_summon: true` in your config."
200+
)
195201

196202
# special: trinity train_batch_size
197203
if backbone == "trinity":

tutorial/example_train_multi_model/trans_roll_lora.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def main():
5252
lora_rank=32,
5353
lora_alpha=32,
5454
lora_load_format="safetensors",
55+
layered_summon=True,
5556
lr=3e-4,
57+
5658
)
5759

5860
job_7b = AgentJetJob(
@@ -67,6 +69,7 @@ def main():
6769
lora_rank=32,
6870
lora_alpha=32,
6971
lora_load_format="safetensors",
72+
layered_summon=True,
7073
lr=3e-4,
7174
)
7275

tutorial/example_werewolves_swarm/agent_roll_v2.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ class ExperimentConfig:
108108
max_parallel: int = 64
109109
discard_episode_timeout: int = 240
110110
project_name: str = "werewolves_multi_model"
111+
# Random player split mode: at each episode start, randomly split
112+
# good-side players among trainable models (ignoring role-based assignment)
113+
random_player_split: bool = False
111114

112115
def __post_init__(self):
113116
# Validate that all trainable roles are from the same faction
@@ -150,26 +153,29 @@ def __init__(
150153
swarm_clients: Dict[str, SwarmClient],
151154
opponent_model: str,
152155
opponent_url: str,
156+
random_player_split: bool = False,
153157
):
154158
self.model_configs = model_configs
155159
self.swarm_clients = swarm_clients
156160
self.opponent_model = opponent_model
157161
self.opponent_url = opponent_url
162+
self.random_player_split = random_player_split
158163

159164
# Build role -> model_id mapping (for roles without index constraints)
160165
self.role_to_model: Dict[str, str] = {}
161166
# Build (role, index) -> model_id mapping (for indexed assignments)
162167
self.role_index_to_model: Dict[Tuple[str, int], str] = {}
163168

164-
for mc in model_configs:
165-
for role in mc.roles:
166-
if mc.role_indices and role in mc.role_indices:
167-
# Index-based assignment
168-
for idx in mc.role_indices[role]:
169-
self.role_index_to_model[(role, idx)] = mc.model_id
170-
else:
171-
# Role-based assignment (all instances)
172-
self.role_to_model[role] = mc.model_id
169+
if not random_player_split:
170+
for mc in model_configs:
171+
for role in mc.roles:
172+
if mc.role_indices and role in mc.role_indices:
173+
# Index-based assignment
174+
for idx in mc.role_indices[role]:
175+
self.role_index_to_model[(role, idx)] = mc.model_id
176+
else:
177+
# Role-based assignment (all instances)
178+
self.role_to_model[role] = mc.model_id
173179

174180
def get_trainable_targets(self) -> List[str]:
175181
"""Get all trainable roles across all models."""
@@ -216,17 +222,38 @@ async def execute(
216222
# Track which model each player uses
217223
player_to_model: Dict[int, str] = {}
218224

225+
# For random_player_split mode: randomly assign good-side players to models
226+
player_to_model_split: Dict[int, str] = {}
227+
if self.random_player_split:
228+
# Identify all good-side player indices
229+
good_player_indices = [i for i, role in enumerate(roles) if role in GOOD_ROLES]
230+
# Shuffle and split 50/50
231+
np.random.shuffle(good_player_indices)
232+
half = len(good_player_indices) // 2
233+
model_ids = [mc.model_id for mc in self.model_configs]
234+
for i, player_idx in enumerate(good_player_indices):
235+
# First half -> M1, second half -> M2
236+
model_id_for_player = model_ids[0] if i < half else model_ids[1]
237+
player_to_model_split[player_idx] = model_id_for_player
238+
logger.info(f"Random player split: M1={[p for p, m in player_to_model_split.items() if m == model_ids[0]]}, "
239+
f"M2={[p for p, m in player_to_model_split.items() if m == model_ids[1]]}")
240+
219241
# Initialize agents
220242
players = []
221243
for i, role in enumerate(roles):
222244
# Get the index of this role instance (0, 1, 2 for werewolves, etc.)
223245
role_idx = role_counters.get(role, 0)
224246
role_counters[role] = role_idx + 1
225247

226-
# Try to find model: first by (role, index), then by role only
227-
model_id = self.role_index_to_model.get((role, role_idx))
228-
if model_id is None:
229-
model_id = self.role_to_model.get(role)
248+
# Determine model_id based on assignment mode
249+
if self.random_player_split:
250+
# In random split mode, use player-based assignment
251+
model_id = player_to_model_split.get(i)
252+
else:
253+
# Try to find model: first by (role, index), then by role only
254+
model_id = self.role_index_to_model.get((role, role_idx))
255+
if model_id is None:
256+
model_id = self.role_to_model.get(role)
230257

231258
if model_id is None:
232259
# Non-trainable role - use opponent model
@@ -326,6 +353,8 @@ def setup(self):
326353
lora_rank=mc.lora.rank if mc.lora.enabled else None,
327354
lora_alpha=mc.lora.alpha if mc.lora.enabled else None,
328355
lora_target_modules=mc.lora.target_modules if mc.lora.enabled else None,
356+
lr=3e-4,
357+
layered_summon=True,
329358
)
330359

331360
self.jobs[mc.model_id] = job
@@ -358,6 +387,7 @@ def run(self):
358387
swarm_clients=self.swarm_clients,
359388
opponent_model=self.config.opponent_model,
360389
opponent_url=self.config.opponent_url,
390+
random_player_split=self.config.random_player_split,
361391
)
362392

363393
def rollout(task: Task):
@@ -405,7 +435,7 @@ def rollout(task: Task):
405435
# Predefined Experiment Configurations
406436
# ============================================================================
407437

408-
VERSION = "v2"
438+
VERSION = "v3"
409439

410440

411441
def get_exp1_config() -> ExperimentConfig:
@@ -529,35 +559,36 @@ def get_exp3_config() -> ExperimentConfig:
529559

530560
def get_exp4_config() -> ExperimentConfig:
531561
"""
532-
Experiment 4: Two models with random 50/50 split of good roles.
533-
- M1 (14B-LoRA): 50% random non-werewolf characters
534-
- M2 (14B-LoRA): remaining 50% non-werewolf characters
562+
Experiment 4: Two models with random 50/50 split of good-side players.
563+
- M1 (14B-LoRA): randomly selected 50% of good-side players per episode
564+
- M2 (14B-LoRA): remaining 50% of good-side players
535565
- Opponents (235B): werewolf
536566
537-
Role assignment is randomized per game.
567+
At the start of each episode, the 6 good-side players (3 villagers,
568+
1 seer, 1 witch, 1 hunter) are randomly split: 3 players go to M1,
569+
3 players go to M2. This is player-based, not role-based assignment.
538570
"""
539-
# For simplicity, we do a fixed 50/50 split here
540-
# In practice, the split could be randomized per game
541571
return ExperimentConfig(
542572
model_configs=[
543573
ModelConfig(
544574
model_id="M1",
545575
swarm_url="http://localhost:10086",
546576
model_path=DEFAULT_MODEL_14B,
547-
roles=["villager", "seer"], # ~50% of good roles
577+
roles=GOOD_ROLES, # All good roles (for validation only)
548578
lora=LoraConfig(enabled=True, rank=32, alpha=32),
549579
experiment_name=f"werewolves_exp4_m1_half_{VERSION}",
550580
),
551581
ModelConfig(
552582
model_id="M2",
553583
swarm_url="http://localhost:10087",
554584
model_path=DEFAULT_MODEL_14B,
555-
roles=["witch", "hunter"], # ~50% of good roles
585+
roles=GOOD_ROLES, # All good roles (for validation only)
556586
lora=LoraConfig(enabled=True, rank=32, alpha=32),
557587
experiment_name=f"werewolves_exp4_m2_half_{VERSION}",
558588
),
559589
],
560590
project_name="werewolves_exp4_random_split",
591+
random_player_split=True, # Enable random player-based assignment
561592
)
562593

563594

0 commit comments

Comments
 (0)