Skip to content

Commit 484d1bc

Browse files
committed
feat: implement PeriodicDrainThreadPoolExecutor for improved task management and auto-retry functionality
1 parent 29b4169 commit 484d1bc

File tree

2 files changed

+56
-25
lines changed

2 files changed

+56
-25
lines changed

ajet/utils/thread_executors.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from concurrent.futures import ThreadPoolExecutor
22
from ajet.utils.sington import singleton
3+
from loguru import logger
34
import threading
45

56

@@ -41,3 +42,40 @@ def wrapped_fn(*args, **kwargs):
4142
def shutdown(self, wait=True):
4243
self.executor.shutdown(wait=wait)
4344

45+
class PeriodicDrainThreadPoolExecutor:
    """A ThreadPoolExecutor wrapper that can auto-retry failed tasks and
    periodically drain all in-flight work.

    Back-pressure is applied by ``submit_with_periodic_drain``: after every
    ``workers`` submissions the current executor is shut down (waiting for
    every pending task to finish) and replaced with a fresh pool, so at most
    one "generation" of tasks can accumulate.
    """

    def __init__(self, workers=100, auto_retry=True):
        # `workers` doubles as the pool size AND the drain interval.
        self._max_workers = workers
        self._executor = ThreadPoolExecutor(max_workers=workers)
        # Counts submit_with_periodic_drain() calls only (plain submit()
        # does not participate in the drain cycle).
        self._submitted_count = 0
        # When True, a failed task is re-run until it succeeds.
        self._auto_retry = auto_retry

    def submit(self, fn, *args, **kwargs):
        """Submit a task; when auto_retry is enabled, re-run it on failure.

        Returns the Future of the (possibly retry-wrapped) task.
        """

        def retry_wrapper(func, *wargs, **wkwargs):
            # BUG FIX: original signature was (func, arg) — it broke for
            # zero args, multiple positional args, and ANY keyword args
            # (e.g. submit(fn, task=task) raised TypeError inside the pool).
            while True:
                try:
                    return func(*wargs, **wkwargs)
                except Exception as e:
                    # Log with traceback and retry forever; the task is
                    # assumed to be idempotent / safe to re-run.
                    logger.exception(f"[PeriodicDrainThreadPoolExecutor] Error executing task: {e}. Retrying...")

        if self._auto_retry:
            return self._executor.submit(retry_wrapper, fn, *args, **kwargs)
        else:
            return self._executor.submit(fn, *args, **kwargs)

    def submit_with_periodic_drain(self, fn, *args, **kwargs):
        """Submit a task, draining all in-flight work every ``workers`` submissions.

        A drain shuts down the current executor (blocking until every
        pending task completes) and starts a fresh pool of the same size.

        NOTE(review): the count/drain sequence is not synchronized — callers
        should not invoke this concurrently from multiple threads. Futures
        returned before a drain belong to the already shut-down pool.
        """
        drain_every_n_job = self._max_workers
        if self._submitted_count > 0 and self._submitted_count % drain_every_n_job == 0:
            # Wait for the current generation to finish, then start fresh.
            self._executor.shutdown(wait=True)
            self._executor = ThreadPoolExecutor(max_workers=self._max_workers)

        self._submitted_count += 1
        return self.submit(fn, *args, **kwargs)

    def shutdown(self, wait=True):
        """Shut down the underlying executor."""
        self._executor.shutdown(wait=wait)

tutorial/example_math_swarm/math.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,27 @@
1+
import os
12
import re
23
import requests
34
from textwrap import dedent
45
from ajet.schema.task import Task, WorkflowOutput
56
from ajet.copilot.job import AgentJetJob
67
from ajet.task_reader import RouterTaskReader
78
from ajet.utils.retry import retry_with_backoff
8-
from ajet.utils.thread_executors import BoundedThreadPoolExecutor
9+
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
910
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
1011
from ajet.tuner_lib.experimental.interchange_utils import SwarmThrottlePolicy
1112
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
1213
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, SwarmThrottlePolicy
1314

14-
# --------- configurations that take effect locally -------------
15-
LOCAL_GRPO_N = 4 # grpo group size
16-
LOCAL_NUM_EPOCH = 10000
17-
LOCAL_NUM_EPOCH = 1
18-
LOCAL_MAX_PARALLEL = 64
19-
LOCAL_DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main"
20-
REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
21-
22-
# --------- configurations that take effect remotely -------------
23-
REMOTE_BATCH_SIZE = 32
24-
REMOTE_ALLOCATE_GPU_PER_NODE = 4
25-
REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
26-
2715
# python -m tutorial.example_math_swarm.math
2816

29-
class WeightUpdatedHalfway(Exception):
30-
"""Raised when the remote side starts updating model weights halfway through an episode."""
17+
GRPO_N = 4 # grpo group size
18+
NUM_EPOCH = 10000
19+
DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main"
20+
AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
3121

22+
REMOTE_BATCH_SIZE = 32
23+
REMOTE_ALLOCATE_GPU_PER_NODE = 4
24+
REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
3225

3326
def main():
3427

@@ -37,21 +30,21 @@ def main():
3730
reader_type = "huggingface_dat_repo",
3831
reader_config = AjetTaskReader(
3932
huggingface_dat_repo = HuggingfaceDatRepo(
40-
dataset_path = LOCAL_DATASET_PATH
33+
dataset_path = DATASET_PATH
4134
)
4235
)
4336
)
4437

4538
# # Hand shake with remote swarm server
46-
swarm_worker = SwarmClient(REMOTE_SWARM_URL)
39+
swarm_worker = SwarmClient(AJET_SWARM_URL)
4740
swarm_worker.auto_sync_train_config_and_start_engine(
4841
AgentJetJob(
4942
experiment_name="math_gsm8k_grpo",
5043
algorithm="grpo",
5144
n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
52-
model=REMOTE_TRAIN_MODEL_01,
45+
model=REMOTE_TRAIN_MODEL,
5346
batch_size=REMOTE_BATCH_SIZE,
54-
num_repeat=LOCAL_GRPO_N,
47+
num_repeat=GRPO_N,
5548
)
5649
)
5750

@@ -62,7 +55,7 @@ def rollout(task):
6255
throttle_policy=SwarmThrottlePolicy(
6356
ratio=0.5,
6457
expected_batch_size=REMOTE_BATCH_SIZE,
65-
expected_num_repeat=LOCAL_GRPO_N,
58+
expected_num_repeat=GRPO_N,
6659
current_task_id=task.task_id
6760
)
6861
)
@@ -76,11 +69,11 @@ def rollout(task):
7669
except:
7770
pass
7871

79-
executor = BoundedThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL)
80-
for epoch in range(LOCAL_NUM_EPOCH):
72+
executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
73+
for _ in range(NUM_EPOCH):
8174
for _, task in enumerate(dataset.generate_training_tasks()):
82-
for _ in range(LOCAL_GRPO_N):
83-
executor.submit(rollout, task)
75+
for _ in range(GRPO_N):
76+
executor.submit_with_periodic_drain(fn=rollout, task=task)
8477

8578
return None
8679

0 commit comments

Comments
 (0)