Skip to content

Commit 3157658

Browse files
committed
fix state machine bugs
1 parent 8777bcb commit 3157658

20 files changed

Lines changed: 533 additions & 313 deletions

ajet/__init__.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,4 @@
1-
from ajet.copilot.job import AgentJetJob
2-
from ajet.schema.task import WorkflowOutput, WorkflowTask
3-
from ajet.tuner import AjetTuner
4-
from ajet.workflow import Workflow
5-
from ajet.utils.vsdb import vscode_conditional_breakpoint as bp
1+
__version__ = "0.1.0"
62

73
__all__ = [
84
"Workflow",
@@ -13,4 +9,29 @@
139
"bp"
1410
]
1511

16-
__version__ = "0.1.0"
12+
# Public name -> module that defines it. Nothing listed here is imported
# until the name is first accessed on this package (see __getattr__ below).
_LAZY_IMPORTS = {
    "AjetTuner": "ajet.tuner",
    "AgentJetJob": "ajet.copilot.job",
    "WorkflowOutput": "ajet.schema.task",
    "WorkflowTask": "ajet.schema.task",
    "Workflow": "ajet.workflow",
    "bp": "ajet.utils.vsdb",
}

# Public name -> attribute name inside the defining module, for the exports
# whose public alias differs from the name used in that module.
_ATTR_MAPPING = {
    "bp": "vscode_conditional_breakpoint",
}


def __getattr__(name):
    """Resolve a lazily exported attribute on first access (PEP 562).

    Names listed in ``_LAZY_IMPORTS`` are imported from their defining
    module on demand, renamed through ``_ATTR_MAPPING`` when an alias is
    registered, and then cached in this module's globals so that later
    accesses no longer go through this hook.

    Raises:
        AttributeError: if ``name`` is not a registered lazy export.
    """
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    import importlib

    module = importlib.import_module(_LAZY_IMPORTS[name])
    attr_name = _ATTR_MAPPING.get(name, name)
    value = getattr(module, attr_name)  # type: ignore

    # Cache the resolved object: the next lookup hits globals() directly.
    globals()[name] = value
    return value


def __dir__():
    """List the module globals together with the not-yet-loaded lazy exports."""
    return sorted(set(globals()) | set(_LAZY_IMPORTS))

ajet/context_tracker/multiagent_tracking.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@ def __init__(
4949
tokenizer: PreTrainedTokenizer,
5050
config,
5151
should_interrupt_fn,
52+
should_interrupt_hard_fn,
5253
generated_token_callback_fn,
5354
**kwargs,
5455
):
5556
super().__init__(config, tokenizer, **kwargs)
5657
self.tokenizer = tokenizer
5758
self.should_interrupt_fn = should_interrupt_fn
59+
self.should_interrupt_hard_fn = should_interrupt_hard_fn
5860
self.generated_token_callback_fn = generated_token_callback_fn
5961
self.context_overflow = False
6062
self.output_kwargs = {}

ajet/copilot/job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
algorithm: str = "grpo",
4545
n_gpu_for_infer: int | None = None, # only for trinity backbone
4646
grpo_n: int = 8,
47+
batch_size: int = 32,
4748
tinkerscript_mode: bool = True,
4849
*kwargs,
4950
) -> None:
@@ -60,6 +61,7 @@ def __init__(
6061
self.config.ajet.trainer_common.n_gpus_per_node = n_gpu
6162
self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm
6263
self.config.ajet.rollout.num_repeat = grpo_n
64+
self.config.ajet.data.train_batch_size = batch_size
6365
if n_gpu_for_infer is None and backbone == "trinity":
6466
raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.")
6567
if (n_gpu_for_infer is not None) and backbone == "verl":

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ ajet:
290290
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
291291
interchange_server_port: 'auto'
292292
num_fastapi_process: 2 # 1, 2 or 4 is fine
293-
max_fastapi_threads: 128 # 64 or 128 is fine
293+
max_fastapi_threads: 512 # 64 or 128 is also fine
294294
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
295295
already_started: False # do not edit, used by `tinkerscript`
296296

ajet/default_config/ajet_ts_default.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ ajet:
77

88
model:
99
# which model should be trained
10-
path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct
10+
path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct
1111

1212
rollout:
1313
# the path to the workflow class
@@ -29,10 +29,13 @@ ajet:
2929
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
3030
interchange_server_port: 10086
3131
num_fastapi_process: 2 # 1, 2 or 4 is fine
32-
max_fastapi_threads: 128 # 64 or 128 is fine
32+
max_fastapi_threads: 512 # 64 or 128 is also fine
3333
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
3434
already_started: False # do not edit, used by `tinkerscript`
3535

36+
rollout:
37+
# maximum number of parallel environments / simulate workers
38+
max_env_worker: 128
3639

3740

3841
# ------------------ no modification needed ------------------

ajet/task_rollout/native_parallel_worker.py

Lines changed: 34 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,19 @@
88

99
import numpy as np
1010
import torch
11+
import threading
1112
from loguru import logger
1213
from tensordict import TensorDict
1314
from torch.nn.utils.rnn import pad_sequence
1415
from tqdm import tqdm
1516
from verl import DataProto
1617
from verl.utils.torch_functional import pad_sequence_to_length
1718

18-
from ajet.context_tracker.basic_tracker import BaseContextTracker
1919
from ajet.schema.task import Task
2020
from ajet.schema.trajectory import Sample
2121
from ajet.task_rollout.single_worker import BaseRolloutManager
22+
from ajet.context_tracker.basic_tracker import BaseContextTracker
23+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
2224

2325

2426
class DynamicRolloutManager(BaseRolloutManager):
@@ -481,33 +483,39 @@ def rollout_swarm( # noqa: C901
481483
tracker_array: List[BaseContextTracker] = []
482484
assert mode != "validate"
483485
rollout_n = self.rollout_n
484-
n_task = len(tasks)
486+
n_batch_task = len(tasks)
487+
n_task = min(len(tasks), self.max_parallel // rollout_n)
488+
assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}"
485489
self.current_token_count_time = time.time()
486490

487491
# initialize observation window
488492
observation_window: Dict[str, List[int | bool | str]] = {
489493
"info": ["" for _ in range(n_task * rollout_n)],
490494
"step": [0 for _ in range(n_task * rollout_n)],
491495
"stop": [False for _ in range(n_task * rollout_n)],
496+
"hard_stop": [False for _ in range(n_task * rollout_n)],
492497
"token": [0 for _ in range(n_task * rollout_n)],
493498
}
494499
executor = ThreadPoolExecutor(max_workers=self.max_parallel)
495500
futures: List[Future] = []
496501
completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {}
502+
executor_lock = threading.Lock()
497503

498504
# submit initial tasks
499505
dummy_task = Task(main_query="dummy task")
500506
for task_batch_index in range(n_task):
501507
for task_rollout_index in range(rollout_n):
502508
task_thread_index = task_batch_index * rollout_n + task_rollout_index
503509
future = executor.submit(
504-
self.rollout_env_worker,
510+
self.rollout_env_worker_loop,
505511
task=dummy_task,
506512
task_tag="",
507513
mode=mode,
508514
task_batch_index=task_batch_index,
509515
task_thread_index=task_thread_index,
510516
observation_window=observation_window,
517+
completed_task_id_map_ct=completed_task_id_map_ct,
518+
executor_lock=executor_lock,
511519
)
512520
observation_window["info"][task_thread_index] = "1"
513521
futures.append(future)
@@ -516,14 +524,15 @@ def enough_sample_stop_condition(completed_task_id_map_ct) -> bool:
516524
n = 0
517525
for ct_list in completed_task_id_map_ct.values():
518526
n += len(ct_list)
519-
return (n >= n_task * rollout_n)
527+
print(f"Current collected samples: {n}, target: {n_batch_task * rollout_n}")
528+
return (n >= n_batch_task * rollout_n)
520529

521530
def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
522531
n_finish_roll_task = 0
523532
for ct_list in completed_task_id_map_ct.values():
524533
if len(ct_list) >= rollout_n:
525534
n_finish_roll_task += 1
526-
return (n_finish_roll_task >= n_task)
535+
return (n_finish_roll_task >= n_batch_task)
527536

528537
def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
529538
n_finish_roll_task = 0
@@ -535,63 +544,39 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
535544
all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array)
536545
if all_equal: continue
537546
n_finish_roll_task += 1
538-
return (n_finish_roll_task >= n_task)
547+
return (n_finish_roll_task >= n_batch_task)
539548

540549
stop_condition = enough_sample_stop_condition
541550

542-
def force_stop_all_threads():
543-
for k in range(len(observation_window["stop"])):
544-
observation_window["stop"][k] = True
551+
def stop_all_threads_soft():
552+
for k in range(len(observation_window["stop"])): observation_window["stop"][k] = True
553+
http_change_engine_status(self.config, "ENGINE.ROLLING_POST")
554+
return
555+
556+
def stop_all_threads_hard():
557+
for k in range(len(observation_window["hard_stop"])): observation_window["hard_stop"][k] = True
558+
http_change_engine_status(self.config, "ENGINE.WEIGHT_SYNCING")
545559
return
546560

547-
tic = time.time()
561+
cnt = 0
548562
while True:
549-
# wait for a completed task
550-
done_arr, pending_arr = wait(futures, timeout=10, return_when=FIRST_COMPLETED)
551-
print(f"Done tasks: {len(done_arr)}, Pending tasks: {len(pending_arr)}")
552-
toc = time.time()
553-
if (toc - tic) > 8:
554-
tic = toc
563+
cnt += 1
564+
time.sleep(2)
565+
if (cnt % 5 == 0):
555566
self.step_status_printer(observation_window)
556-
# get result
557-
for future in done_arr:
558-
ct: BaseContextTracker = future.result()
559-
if ct.task_id not in completed_task_id_map_ct:
560-
completed_task_id_map_ct[ct.task_id] = [ct]
561-
else:
562-
completed_task_id_map_ct[ct.task_id] += [ct]
563-
# if meet stop condition
564567
meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct)
565568
if meet_stop_condition_after_new_results:
566-
force_stop_all_threads()
569+
print("Sending soft stop signal to all threads...")
570+
stop_all_threads_soft()
567571
break
568-
else:
569-
# re-spawn new tasks for done futures
570-
for task_batch_index in range(n_task):
571-
for task_rollout_index in range(rollout_n):
572-
task_thread_index = task_batch_index * rollout_n + task_rollout_index
573-
has_done = (futures[task_thread_index] in done_arr)
574-
575-
observation_window["info"][task_thread_index] = str(int(observation_window["info"][task_thread_index]) + 1)
576-
observation_window["stop"][task_thread_index] = False
577-
observation_window["step"][task_thread_index] = 0
578-
579-
if has_done:
580-
print(f"Re-spawning thread {task_thread_index}...")
581-
future = executor.submit(
582-
self.rollout_env_worker,
583-
task=dummy_task,
584-
task_tag="",
585-
mode=mode,
586-
task_batch_index=task_batch_index,
587-
task_thread_index=task_thread_index,
588-
observation_window=observation_window,
589-
)
590-
futures[task_thread_index] = future
591572

592573
# wait for all threads to complete
593574
print('Finalizing all threads...')
594-
wait(futures, return_when=ALL_COMPLETED)
575+
executor.shutdown(wait=True)
576+
577+
# stop all threads hard
578+
print("Sending hard stop signal to all threads...")
579+
stop_all_threads_hard()
595580

596581
# build tracker_array
597582
print('Collecting results...')

ajet/task_rollout/single_worker.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Single worker primitives for environment rollouts."""
22

33
import uuid
4+
import time
5+
import threading
46
from typing import Literal
57

68
from loguru import logger
79
from omegaconf import DictConfig
10+
from typing import Dict, List, Literal
811
from transformers.tokenization_utils import PreTrainedTokenizer
912

1013
from ajet.context_tracker.basic_tracker import BaseContextTracker
@@ -14,9 +17,9 @@
1417
from ajet.task_runner.general_runner import GeneralRunner
1518
from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner
1619
from ajet.utils.retry import retry_with_backoff
20+
from ajet.utils.retry import SwarmReceiveAbortException
1721
from ajet.utils.sample import get_sample_params
1822
from ajet.utils.testing_utils import TestFailException, TestSuccessException
19-
from ajet.task_runner.tinkerscript_runner import SwarmReceiveAbortException
2023

2124

2225
class BaseRolloutManager:
@@ -125,6 +128,7 @@ def rollout_env_worker(
125128
workflow_task=workflow_task,
126129
)
127130
except SwarmReceiveAbortException as exc: # noqa: BLE001
131+
print('SwarmReceiveAbortException caught in rollout_env_worker')
128132
return None # type: ignore
129133
except TestSuccessException as e:
130134
logger.success(
@@ -141,3 +145,54 @@ def rollout_env_worker(
141145
raise e
142146

143147
return tracker
148+
149+
150+
def rollout_env_worker_loop(
    self,
    task: "Task",
    task_batch_index: int,
    task_tag: str,
    mode: Literal["sample", "validate"],
    task_thread_index: int,
    observation_window: dict,
    completed_task_id_map_ct: Dict[str, List["BaseContextTracker"]],
    executor_lock: threading.Lock,
    **kwargs,
):
    """Run ``rollout_env_worker`` repeatedly until this thread's stop flag is set.

    Each iteration performs one rollout. A rollout whose tracker is truthy
    and has a truthy ``reward_structure`` is appended, under
    ``executor_lock``, to ``completed_task_id_map_ct[tracker.task_id]``;
    any other result is discarded and the rollout is retried.

    Args:
        task, task_batch_index, task_tag, mode, task_thread_index, **kwargs:
            forwarded unchanged to ``rollout_env_worker``.
        observation_window: shared status dict. This function reads
            ``observation_window["stop"][task_thread_index]`` and writes the
            1-based attempt counter, as a string, to
            ``observation_window["info"][task_thread_index]``.
        completed_task_id_map_ct: shared mapping from task id to the trackers
            collected so far; mutated in place.
        executor_lock: lock guarding writes to ``completed_task_id_map_ct``.

    Returns:
        None. Results are delivered through ``completed_task_id_map_ct``.

    Raises:
        Exception: anything raised inside the loop is logged and re-raised.
    """
    try:
        cnt = 1
        while True:
            if observation_window["stop"][task_thread_index]:
                print('rollout_env_worker_loop received stop signal, exiting...')
                return

            observation_window["info"][task_thread_index] = str(cnt)
            tracker = self.rollout_env_worker(
                task=task,
                task_batch_index=task_batch_index,
                task_tag=task_tag,
                mode=mode,
                task_thread_index=task_thread_index,
                observation_window=observation_window,
                **kwargs,
            )

            if tracker and tracker.reward_structure:
                # The map is shared by every worker thread: serialize writes.
                with executor_lock:
                    completed_task_id_map_ct.setdefault(tracker.task_id, []).append(tracker)
                # Only recorded rollouts advance the counter shown in "info".
                cnt += 1
                if observation_window["stop"][task_thread_index]:
                    return

            # Drop the local reference before the next rollout starts, so a
            # discarded tracker is not kept alive for the whole next iteration.
            # NOTE(review): when rollout_env_worker returns a falsy tracker
            # immediately, this loop retries with no delay -- confirm intended.
            del tracker

    except Exception as e:
        logger.exception(
            f"encounter exception in env_worker_loop error={e.args}"
        )
        # Bare raise re-raises the active exception with its traceback intact.
        raise

ajet/task_runner/base_runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@ def get_judge(self) -> BaseJudge: # type: ignore
4949

5050
def runner_hooks(self, observation_window, task_thread_index, workflow_task):
5151
def should_interrupt_fn() -> bool:
52-
if (observation_window["stop"] is not None) and observation_window["stop"][
53-
task_thread_index
54-
]: # Check if the thread should stop (because other threads have completed, making this thread useless)
52+
if (observation_window["stop"] is not None) and observation_window["stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless)
53+
return True
54+
return False
55+
56+
def should_interrupt_hard_fn() -> bool:
57+
if (observation_window["hard_stop"] is not None) and observation_window["hard_stop"][task_thread_index]: # Check if the thread should stop (because other threads have completed, making this thread useless)
5558
return True
5659
return False
5760

@@ -60,6 +63,7 @@ def generated_token_callback_fn(token_array):
6063

6164
return {
6265
"should_interrupt_fn": should_interrupt_fn,
66+
"should_interrupt_hard_fn": should_interrupt_hard_fn,
6367
"generated_token_callback_fn": generated_token_callback_fn,
6468
}
6569

ajet/task_runner/general_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

2-
from ajet import AjetTuner
3-
from ajet import WorkflowOutput
2+
from ajet.tuner import AjetTuner
3+
from ajet.schema.task import WorkflowOutput, WorkflowTask
44
from ajet.context_tracker.multiagent_tracking import (
55
MultiAgentContextTracker,
66
)

0 commit comments

Comments
 (0)