Skip to content

Commit 8777bcb

Browse files
committed
stage swarm server
1 parent 920e4d5 commit 8777bcb

File tree

19 files changed

+859
-540
lines changed

19 files changed

+859
-540
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def fit(self): # noqa: C901
493493

494494
# perform validation before training
495495
# currently, we only support validation using the reward_function.
496-
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
496+
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_tinkerscript_mode):
497497
val_metrics = self._validate()
498498
assert val_metrics, f"{val_metrics=}"
499499
pprint(f"Initial validation metrics: {val_metrics}")
@@ -784,6 +784,7 @@ def fit(self): # noqa: C901
784784
self.val_reward_fn is not None
785785
and self.config.trainer.test_freq > 0
786786
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
787+
and (not self.config.ajet.enable_tinkerscript_mode)
787788
):
788789
with marked_timer("testing", timing_raw, color="green"):
789790
val_metrics: dict = self._validate()
@@ -934,17 +935,16 @@ def _validate(self):
934935
self.async_rollout_manager.wake_up()
935936
main_val_dataset = self.get_eval_dataset()
936937

937-
logger.info("=" * 10 + "start validate rollout" + "=" * 10)
938+
logger.info("Starting validate rollout")
938939
context_tracker_arr, tasks, val_metrics = self.eval_dataset(
939940
target_dataset=main_val_dataset,
940941
target_dataset_name="main_val_dataset",
941942
mode="validate",
942943
epoch="test.1",
943944
)
944-
logger.info("=" * 10 + "end validate rollout" + "=" * 10)
945+
logger.info("Completed validate rollout")
945946
test_output_gen_batch = self.parallel_env.to_dataproto(context_tracker_arr)
946947
self.async_rollout_manager.sleep()
947-
logger.info("validation generation end")
948948

949949
# Store generated outputs
950950
output_ids = test_output_gen_batch.batch["responses"]

ajet/context_tracker/base_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
115115

116116
self.workflow_task = workflow_task
117117
self.task_batch_index = self.workflow_task.task_batch_index
118-
self.task_tag = self.workflow_task.task_tag
119-
self.task_id = self.workflow_task.task_id
118+
self.task_tag: str = self.workflow_task.task_tag
119+
self.task_id: str = self.workflow_task.task_id
120120
self.episode_uuid = self.workflow_task.episode_uuid
121121

122122
self.config = config

ajet/task_reader/hf_dataset_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import List, Generator
21

32
import datasets
43

54
from ajet.schema.task import Task
5+
from typing import List, Generator
66
from ajet.task_reader.task_reader_base import BaseTaskReader
77

88

@@ -38,7 +38,7 @@ def _load_dataset_split(self, split: str):
3838
# Load from Hugging Face hub
3939
dataset = datasets.load_dataset(self.dataset_name, split=split)
4040
# shuffle dataset
41-
dataset = dataset.shuffle(seed=42)
41+
dataset = dataset.shuffle()
4242
except Exception as e:
4343
raise ValueError(
4444
f"Failed to load dataset '{self.dataset_name}' with split '{split}': {str(e)}"

ajet/task_rollout/native_parallel_worker.py

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
import time
5-
from concurrent.futures import Future, ThreadPoolExecutor
5+
from concurrent.futures import Future, ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED
66
from typing import Dict, List, Literal
77
from urllib.parse import quote
88

@@ -59,6 +59,9 @@ def step_status_printer(self, observation_window):
5959
if start == -1:
6060
print_buf += [f"[finished]:{count} threads"]
6161
print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf))
62+
if "info" in observation_window:
63+
print_buf2 = "\t".join(observation_window["info"])
64+
print(print_buf2)
6265

6366
def rollout_static(
6467
self,
@@ -139,7 +142,9 @@ def rollout(
139142
epoch: str,
140143
) -> List[BaseContextTracker]:
141144
"""Delegate to dynamic rollout when oversampling is enabled."""
142-
if (
145+
if self.config.ajet.enable_tinkerscript_mode:
146+
return self.rollout_swarm(tasks, mode, epoch)
147+
elif (
143148
mode == "sample"
144149
and (self.rollout_n != 1)
145150
and self.config.ajet.rollout.enable_oversample
@@ -459,6 +464,144 @@ def rollout_dynamic( # noqa: C901
459464
return tracker_array
460465

461466

467+
468+
def rollout_swarm(  # noqa: C901
    self,
    tasks: List[Task],
    mode: Literal["sample", "validate"],
    epoch: str,
    allow_sample_num_change=True,
    allow_force_stop=True,
) -> List[BaseContextTracker]:
    """Run context trackers in a self-replenishing thread pool ("swarm" mode).

    ``len(tasks) * rollout_n`` worker slots each run ``rollout_env_worker``.
    Whenever a slot's worker finishes, that slot is re-spawned, until the
    configured stop condition is satisfied; remaining workers are then
    signalled to stop via the shared ``observation_window["stop"]`` flags.

    Args:
        tasks: Task batch. Only its length is used here — workers receive a
            placeholder task (presumably the real task is pulled from the
            swarm server inside the worker in tinkerscript mode; TODO confirm).
        mode: Must be "sample"; "validate" is not supported (asserted).
        epoch: Unused by this strategy; kept for interface parity with
            ``rollout`` / ``rollout_dynamic``.
        allow_sample_num_change: Unused; kept for interface parity.
        allow_force_stop: Unused; kept for interface parity.

    Returns:
        All completed context trackers, in task-id grouping order.
    """
    tracker_array: List[BaseContextTracker] = []
    assert mode != "validate"
    rollout_n = self.rollout_n
    n_task = len(tasks)
    n_slots = n_task * rollout_n
    self.current_token_count_time = time.time()

    # Shared per-slot progress/stop state; mutated by workers, read by
    # step_status_printer. "info" tracks how many times a slot was spawned.
    observation_window: Dict[str, List[int | bool | str]] = {
        "info": ["" for _ in range(n_slots)],
        "step": [0 for _ in range(n_slots)],
        "stop": [False for _ in range(n_slots)],
        "token": [0 for _ in range(n_slots)],
    }
    executor = ThreadPoolExecutor(max_workers=self.max_parallel)
    futures: List[Future] = []
    completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {}

    # Workers get a placeholder task in swarm mode.
    dummy_task = Task(main_query="dummy task")

    def submit_worker(task_batch_index: int, task_thread_index: int) -> Future:
        # (Re-)spawn one slot, resetting ONLY that slot's shared state.
        # (Fix: the previous version reset state for every slot on each
        # re-spawn pass, corrupting running threads' progress counters.)
        future = executor.submit(
            self.rollout_env_worker,
            task=dummy_task,
            task_tag="",
            mode=mode,
            task_batch_index=task_batch_index,
            task_thread_index=task_thread_index,
            observation_window=observation_window,
        )
        spawn_count = observation_window["info"][task_thread_index]
        observation_window["info"][task_thread_index] = str(int(spawn_count or 0) + 1)
        observation_window["stop"][task_thread_index] = False
        observation_window["step"][task_thread_index] = 0
        return future

    # submit initial tasks, one slot per (task, rollout) pair
    for task_batch_index in range(n_task):
        for task_rollout_index in range(rollout_n):
            task_thread_index = task_batch_index * rollout_n + task_rollout_index
            futures.append(submit_worker(task_batch_index, task_thread_index))

    def enough_sample_stop_condition(completed_task_id_map_ct) -> bool:
        # Stop once the total number of completed trackers reaches the
        # nominal batch size, regardless of per-task distribution.
        n = sum(len(ct_list) for ct_list in completed_task_id_map_ct.values())
        return n >= n_slots

    def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
        # Stop once every task has at least rollout_n completed trackers.
        n_finish_roll_task = sum(
            1 for ct_list in completed_task_id_map_ct.values() if len(ct_list) >= rollout_n
        )
        return n_finish_roll_task >= n_task

    def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
        # Like the above, but a task only counts as finished when its
        # rewards are NOT all equal (all-equal rewards give no gradient
        # signal under group-relative advantage — TODO confirm intent).
        n_finish_roll_task = 0
        for ct_list in completed_task_id_map_ct.values():
            if len(ct_list) < rollout_n:
                continue
            rewards = [tracker.reward_structure.performance_reward for tracker in ct_list]
            if all(x == rewards[0] for x in rewards):
                continue
            n_finish_roll_task += 1
        return n_finish_roll_task >= n_task

    # NOTE(review): the two alternative conditions above are currently
    # unused; enough_sample_stop_condition is the active policy.
    stop_condition = enough_sample_stop_condition

    def force_stop_all_threads():
        # Cooperative cancellation: workers poll these flags.
        for k in range(len(observation_window["stop"])):
            observation_window["stop"][k] = True

    tic = time.time()
    while True:
        # wait for at least one completed slot (or the 10 s poll timeout)
        done_arr, pending_arr = wait(futures, timeout=10, return_when=FIRST_COMPLETED)
        print(f"Done tasks: {len(done_arr)}, Pending tasks: {len(pending_arr)}")
        toc = time.time()
        if (toc - tic) > 8:
            tic = toc
            self.step_status_printer(observation_window)
        # collect results
        for future in done_arr:
            ct: BaseContextTracker = future.result()
            if ct is None:
                # Worker aborted (e.g. SwarmReceiveAbortException in
                # rollout_env_worker returns None) — nothing to collect.
                continue
            completed_task_id_map_ct.setdefault(ct.task_id, []).append(ct)
        if stop_condition(completed_task_id_map_ct):
            force_stop_all_threads()
            break
        # re-spawn ONLY the slots whose future completed
        for task_batch_index in range(n_task):
            for task_rollout_index in range(rollout_n):
                task_thread_index = task_batch_index * rollout_n + task_rollout_index
                if futures[task_thread_index] in done_arr:
                    print(f"Re-spawning thread {task_thread_index}...")
                    futures[task_thread_index] = submit_worker(
                        task_batch_index, task_thread_index
                    )

    # drain remaining workers (they were signalled to stop above) and
    # release the pool threads (fix: executor was previously never shut down)
    print('Finalizing all threads...')
    wait(futures, return_when=ALL_COMPLETED)
    executor.shutdown(wait=True)

    # build tracker_array from everything collected before the stop
    print('Collecting results...')
    for ct_list in completed_task_id_map_ct.values():
        tracker_array.extend(ct_list)

    return tracker_array
603+
604+
462605
class VerlRolloutManager(DynamicRolloutManager):
463606
"""High-level manager orchestrating rollouts and batch conversion."""
464607

ajet/task_rollout/single_worker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ajet.utils.retry import retry_with_backoff
1717
from ajet.utils.sample import get_sample_params
1818
from ajet.utils.testing_utils import TestFailException, TestSuccessException
19+
from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException
1920

2021

2122
class BaseRolloutManager:
@@ -123,6 +124,8 @@ def rollout_env_worker(
123124
tracker = agent_runner.execute(
124125
workflow_task=workflow_task,
125126
)
127+
except SwarmReceiveAbortException as exc: # noqa: BLE001
128+
return None # type: ignore
126129
except TestSuccessException as e:
127130
logger.success(
128131
f"env_worker.agent_flow completed with TestSuccessException: {e.args}"

0 commit comments

Comments
 (0)