Skip to content

Commit b80bc21

Browse files
authored
Add swarm training mode (#8)
* tinkerscript-v1 * improve tinkerscript * feat(tinkerscript): Add comprehensive design blueprint and workflow documentation * fix mermaid * Remove limitations and Chinese version from documentation Removed limitations section and Chinese version from TinkerScript documentation. * Clarify relationship between TinkerScript and Tinker * Update tinkerscript.md * remove trinity * Add AgentJet image to TinkerScript documentation * feat: implement TinkerScript server functionality and enhance configuration syncing * feat: enhance TinkerScript integration with improved engine status handling and configuration updates * feat: enhance TinkerScript functionality with improved engine status handling and episode management * stage eval code ( to be tested ) * union_gen_batch_via_task_id is to be tested * stage dataset io improvement * stage academic translation agent * stage swarm server * fix state machine bugs * rename to agentjet swarm * update pro-academic-trans agent * revise pro-trans * make rollout more robust * enhance error logging during tracker.tokenize() for better debugging * improve readability * delete exit message
1 parent 3bb6f29 commit b80bc21

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+3521
-424
lines changed

ajet/__init__.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
from ajet.copilot.job import AgentJetJob
2-
from ajet.schema.task import WorkflowOutput, WorkflowTask
3-
from ajet.tuner import AjetTuner
4-
from ajet.workflow import Workflow
5-
from ajet.utils.vsdb import vscode_conditional_breakpoint as bp
1+
__version__ = "0.1.0"
62

73
__all__ = [
84
"Workflow",
@@ -13,4 +9,29 @@
139
"bp"
1410
]
1511

16-
__version__ = "0.1.0"
12+
_LAZY_IMPORTS = {
13+
"AjetTuner": "ajet.tuner",
14+
"AgentJetJob": "ajet.copilot.job",
15+
"WorkflowOutput": "ajet.schema.task",
16+
"WorkflowTask": "ajet.schema.task",
17+
"Workflow": "ajet.workflow",
18+
"bp": "ajet.utils.vsdb",
19+
}
20+
21+
_ATTR_MAPPING = {
22+
"bp": "vscode_conditional_breakpoint"
23+
}
24+
25+
def __getattr__(name):
    """Resolve a lazily exported package attribute on first access (PEP 562).

    Python invokes this module-level hook only when normal attribute lookup
    on the module fails. Names registered in ``_LAZY_IMPORTS`` are imported
    from their owning submodule at that point.

    Args:
        name: Attribute being looked up on the module.

    Returns:
        The object exported under ``name``.

    Raises:
        AttributeError: If ``name`` is not registered in ``_LAZY_IMPORTS``.
    """
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    from importlib import import_module

    owner = import_module(_LAZY_IMPORTS[name])

    # The public name may differ from the name defined in the submodule
    # (see ``_ATTR_MAPPING``); otherwise the same name is used.
    source_attr = _ATTR_MAPPING.get(name, name)
    resolved = getattr(owner, source_attr)  # type: ignore

    # Store the result in the module namespace so later lookups succeed
    # through normal attribute access and never reach this hook again.
    globals()[name] = resolved
    return resolved

ajet/backbone/main_verl.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,29 +22,33 @@
2222
import hydra
2323
import ray
2424
from beast_logger import print_dict
25-
from loguru import logger
26-
from omegaconf import OmegaConf
25+
from omegaconf import DictConfig, OmegaConf
2726
from verl.trainer.ppo.reward import load_reward_manager
2827
from verl.utils.device import is_cuda_available
28+
from verl.utils.dataset.rl_dataset import collate_fn
29+
from torch.utils.data import Dataset as TorchDataset
2930

31+
# Task-reader and sampler helpers used to build the training and validation datasets.
32+
from ajet.task_reader import RouterTaskReader, task_to_standard_dataset
33+
from ajet.utils.process_dataset import create_rl_sampler
3034
from ajet.utils.core_env_vars import get_runtime_env
3135
from ajet.utils.launch_utils import set_loguru_default_color
3236

3337
set_loguru_default_color()
3438

3539

3640
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config: DictConfig) -> None:
    """Hydra entry point that hands the configuration to ``run_ppo``.

    Args:
        config: Configuration object supplied by the ``hydra.main`` decorator,
            which is set up with ``config_path="config"`` and
            ``config_name="ppo_trainer"``.
    """
    run_ppo(config)
4448

4549

4650
# Define a function to run the PPO-like training process
47-
def run_ppo(config) -> None:
51+
def run_ppo(config: DictConfig) -> None:
4852
"""Initialize Ray cluster and run distributed PPO training process.
4953
5054
Args:
@@ -56,7 +60,6 @@ def run_ppo(config) -> None:
5660
if not ray.is_initialized():
5761
# this is for local ray cluster
5862
runtime_env = get_runtime_env(config)
59-
print_dict(runtime_env["env_vars"], "runtime_env")
6063
ray.init(
6164
runtime_env=runtime_env,
6265
num_cpus=config.ray_init.num_cpus,
@@ -110,6 +113,7 @@ def run(self, config):
110113
# Print the initial configuration. `resolve=True` will evaluate symbolic values.
111114
from pprint import pprint
112115

116+
from loguru import logger
113117
from omegaconf import OmegaConf
114118
from verl.utils.fs import copy_to_local
115119

@@ -227,21 +231,13 @@ def run(self, config):
227231
resource_pool_spec=resource_pool_spec, mapping=mapping
228232
)
229233

230-
from verl.utils.dataset.rl_dataset import collate_fn
231-
232-
# Create training and validation datasets.
233-
from ajet.task_reader import (
234-
RouterTaskReader,
235-
task_to_standard_dataset,
236-
)
237-
from ajet.utils.process_dataset import create_rl_sampler
238-
239234
task_reader = RouterTaskReader(
240235
config.ajet.task_reader.type,
241236
config.ajet.task_reader,
242237
)
243-
val_dataset = task_to_standard_dataset(task_reader.get_validation_tasks())
244-
train_dataset = task_to_standard_dataset(task_reader.get_training_tasks())
238+
239+
train_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_training_tasks) # type: ignore
240+
val_dataset: TorchDataset = task_to_standard_dataset(task_reader.generate_validation_tasks) # type: ignore
245241
train_sampler = create_rl_sampler(config.data, train_dataset)
246242

247243
from ajet.backbone.trainer_verl import AjetRayPPOTrainer

ajet/backbone/main_vllm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def run(config):
144144
max_parallel = config.ajet.debug.debug_max_parallel
145145
n_task = config.ajet.debug.debug_first_n_tasks
146146
vllm_port = config.ajet.debug.debug_vllm_port
147+
enable_swarm_mode = config.ajet.enable_swarm_mode
147148

148149
# --------- init ---------
149150
async_rollout_manager = ChatCompletionScheduler(
@@ -166,8 +167,10 @@ def run(config):
166167
tasks = task_reader.get_validation_tasks()
167168
logger.info(tasks[:n_task])
168169
ctx_tracker = parallel_env.rollout(
169-
tasks=tasks[:n_task], mode="sample", epoch="1"
170-
) # "sample" or "validate"
170+
tasks=tasks[:n_task],
171+
mode="sample" if not enable_swarm_mode else "sample-ts", # type: ignore
172+
epoch="1"
173+
)
171174
_ = parallel_env.to_dataproto(ctx_tracker)
172175

173176

@@ -186,6 +189,9 @@ def main(config):
186189
if config.ajet.enable_experimental_interchange_server:
187190
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
188191
start_interchange_server(config)
192+
if config.ajet.enable_swarm_mode:
193+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
194+
http_change_engine_status(config, "ENGINE.ROLLING")
189195

190196
def companion_launch():
191197
import torch

ajet/backbone/trainer_trinity.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,11 +206,9 @@ def __init__(self, config):
206206

207207
dataset_segments = []
208208
if "train" in self.split:
209-
dataset_segments.append(task_to_standard_dataset(task_reader.get_training_tasks()))
209+
dataset_segments.append(task_to_standard_dataset(task_reader.generate_training_tasks)) # type: ignore
210210
if "val" in self.split:
211-
dataset_segments.append(
212-
task_to_standard_dataset(task_reader.get_validation_tasks())
213-
)
211+
dataset_segments.append(task_to_standard_dataset(task_reader.generate_validation_tasks)) # type: ignore
214212
if not dataset_segments:
215213
raise ValueError(
216214
f"Unsupported split '{self.split}'. Expected to contain 'train' or 'val'."

ajet/backbone/trainer_verl.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,29 @@ def parse_reward_from_dataproto(data: DataProto, return_dict=False) -> dict | to
9999
return reward_tensor
100100

101101

102-
def union_gen_batch_via_task_id(tasks, batch: DataProto, gen_batch_output: DataProto, discard_original_batch=False):
    """
    Combine rollout output with its originating batch, matching rows by task_id.

    Args:
        tasks: Tasks whose position in this sequence is used as the row index
            into ``batch``; each must expose a ``task_id`` attribute.
        batch: Batch the rollout was generated from. Not read when
            ``discard_original_batch`` is true.
        gen_batch_output: Rollout output; ``non_tensor_batch["task_ids"]``
            names the task each of its rows belongs to.
        discard_original_batch: When true, ``batch`` is ignored and
            ``gen_batch_output`` is relabelled in place and returned.

    Returns:
        ``batch`` re-indexed to line up with ``gen_batch_output`` and unioned
        with it, or ``gen_batch_output`` itself when
        ``discard_original_batch`` is true.
    """
    if discard_original_batch:
        extras = gen_batch_output.non_tensor_batch

        # Every row produced for the same task shares that task's id as uid.
        extras['uid'] = extras["task_ids"]

        # Number the rows of each task 1, 2, 3, ... in order of appearance
        # and write the label into the existing 'rollout_ids' entry.
        seen_per_task = {}
        for row, tid in enumerate(extras["task_ids"]):
            nth = seen_per_task.get(tid, 0) + 1
            seen_per_task[tid] = nth
            extras['rollout_ids'][row] = f"T{tid}R{nth}"

        logger.info(f'task_id_counter: {seen_per_task}')
        return gen_batch_output

    # Look up, for every generated row, the index of the task it came from.
    index_of_task = {}
    for position, task in enumerate(tasks):
        index_of_task[task.task_id] = position

    row_order = [
        index_of_task[tid]
        for tid in gen_batch_output.non_tensor_batch["task_ids"]
    ]

    # Repeat/reorder the original rows so they align one-to-one with the
    # generated rows before merging the two.
    return batch.select_idxs(row_order).union(gen_batch_output)
112125

113126

114127
def compute_advantage(
@@ -443,6 +456,12 @@ def init_workers(self):
443456
tokenizer=self.tokenizer,
444457
)
445458

459+
def _update_interchange_server_status_flag(self, status: str):
    """Forward an engine status via ``http_change_engine_status`` in swarm mode.

    Does nothing unless both ``ajet.enable_experimental_interchange_server``
    and ``ajet.enable_swarm_mode`` are truthy in ``self.config``.

    Args:
        status: Status value passed through unchanged, together with
            ``self.config``, to ``http_change_engine_status``.
    """
    ajet_cfg = self.config.ajet
    if not ajet_cfg.enable_experimental_interchange_server:
        return
    if not ajet_cfg.enable_swarm_mode:
        return

    # Imported only on this path, so the experimental module is not loaded
    # unless both flags are enabled.
    from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status

    http_change_engine_status(self.config, status)
464+
446465
# #######################################
447466
# training loop
448467
# #######################################
@@ -474,7 +493,7 @@ def fit(self): # noqa: C901
474493

475494
# perform validation before training
476495
# currently, we only support validation using the reward_function.
477-
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
496+
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode):
478497
val_metrics = self._validate()
479498
assert val_metrics, f"{val_metrics=}"
480499
pprint(f"Initial validation metrics: {val_metrics}")
@@ -547,12 +566,13 @@ def fit(self): # noqa: C901
547566

548567
with marked_timer("step", timing_raw):
549568
# generate a batch
550-
logger.info("=== + rollout step begin ===")
569+
logger.info("rollout step begin")
551570
with marked_timer("gen", timing_raw, color="red"):
552571
assert self.async_rollout_mode
553-
logger.info("=== wake up begin ===")
572+
logger.info("wake up begin")
554573
self.async_rollout_manager.wake_up()
555-
logger.info("=== wake up end ===")
574+
self._update_interchange_server_status_flag("ENGINE.ROLLING")
575+
logger.info("wake up end")
556576
tasks: List[Task] = [
557577
dict_to_ajet_task(dict(
558578
task_id=gen_batch.non_tensor_batch["task_id"][i],
@@ -571,15 +591,17 @@ def fit(self): # noqa: C901
571591
]
572592
)
573593
)
574-
logger.info("=" * 10 + "start fit rollout" + "=" * 10)
594+
logger.info("start fit rollout")
575595
self.parallel_env.current_global_steps = self.global_steps
576596
context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout(
577597
tasks, mode="sample", epoch=f"train.{epoch}"
578598
)
579-
logger.info("=" * 10 + "end fit rollout" + "=" * 10)
580-
logger.info("begin to convert context_tracker_arr to dataproto")
599+
600+
# from ajet import bp; bp("BATCH")
601+
602+
logger.info("end fit rollout")
581603
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
582-
logger.info("end convertion")
604+
logger.info("end dataproto convertion")
583605

584606
success_rate = [
585607
traj.reward_structure.success_rate for traj in context_tracker_arr
@@ -622,17 +644,17 @@ def fit(self): # noqa: C901
622644
logger.info(
623645
f"gen_batch_output.info batch.keys={gen_batch_output.batch.keys()}"
624646
)
647+
self._update_interchange_server_status_flag("ENGINE.WEIGHT_SYNCING")
625648
self.async_rollout_manager.sleep()
626-
logger.info("=== - rollout step end ===")
649+
logger.info("rollout step end")
627650

628-
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
629-
raise NotImplementedError("REMAX is not supported in GRPO yet.")
630651

631652
batch.non_tensor_batch["uid"] = np.array(
632653
[str(uuid.uuid4()) for _ in range(len(batch.batch))],
633654
dtype=object,
634655
)
635-
batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output)
656+
discard_original_batch = self.config.ajet.enable_swarm_mode
657+
batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch)
636658
batch.batch["response_mask"] = compute_response_mask(batch)
637659

638660
if "response_mask" not in batch.batch.keys():
@@ -666,7 +688,7 @@ def fit(self): # noqa: C901
666688
)
667689

668690
# recompute old_log_probs
669-
logger.info("=== + compute log_probs begin ===")
691+
logger.info("+ compute log_probs begin")
670692
with marked_timer("old_log_prob", timing_raw, color="blue"):
671693
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
672694
entropys = old_log_prob.batch["entropys"]
@@ -764,6 +786,7 @@ def fit(self): # noqa: C901
764786
self.val_reward_fn is not None
765787
and self.config.trainer.test_freq > 0
766788
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
789+
and (not self.config.ajet.enable_swarm_mode)
767790
):
768791
with marked_timer("testing", timing_raw, color="green"):
769792
val_metrics: dict = self._validate()
@@ -914,17 +937,16 @@ def _validate(self):
914937
self.async_rollout_manager.wake_up()
915938
main_val_dataset = self.get_eval_dataset()
916939

917-
logger.info("=" * 10 + "start validate rollout" + "=" * 10)
940+
logger.info("Starting validate rollout")
918941
context_tracker_arr, tasks, val_metrics = self.eval_dataset(
919942
target_dataset=main_val_dataset,
920943
target_dataset_name="main_val_dataset",
921944
mode="validate",
922945
epoch="test.1",
923946
)
924-
logger.info("=" * 10 + "end validate rollout" + "=" * 10)
947+
logger.info("Completed validate rollout")
925948
test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr)
926949
self.async_rollout_manager.sleep()
927-
logger.info("validation generation end")
928950

929951
# Store generated outputs
930952
output_ids = test_output_gen_batch.batch["responses"]
@@ -938,7 +960,8 @@ def _validate(self):
938960
dtype=object,
939961
)
940962
tasks = tasks[: len(main_val_dataset)]
941-
test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch)
963+
discard_original_batch = self.config.ajet.enable_swarm_mode
964+
test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch)
942965
# test_batch = test_batch.union(test_output_gen_batch)
943966
test_batch.meta_info["validate"] = True
944967

0 commit comments

Comments
 (0)