Skip to content

Commit b15983a

Browse files
committed
make rollout more robust
1 parent 47812cb commit b15983a

File tree

7 files changed

+68
-32
lines changed

7 files changed

+68
-32
lines changed

ajet/task_rollout/native_parallel_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def step_status_printer(self, observation_window):
6161
if start == -1:
6262
print_buf += [f"[finished]:{count} threads"]
6363
print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf))
64-
if "info" in observation_window:
65-
print_buf2 = "\t".join(observation_window["info"])
66-
print(print_buf2)
64+
# if "info" in observation_window:
65+
# print_buf2 = "\t".join(observation_window["info"])
66+
# print(print_buf2)
6767

6868
def rollout_static(
6969
self,

ajet/task_runner/swarm_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def register_episode_and_wait_output(
8888
logger.warning(f"Received reset command for episode {episode_uuid}.")
8989
context_tracker.reset()
9090
zmq_socket.send_string("ack")
91+
continue
9192
elif message == "RUNNER.SPECIAL.ABORT":
9293
logger.warning(f"Received abort command for episode {episode_uuid}.")
9394
context_tracker.reset()
@@ -104,8 +105,8 @@ def register_episode_and_wait_output(
104105
raise exc
105106

106107
finally:
108+
tuner.terminate_episode() # this is very important to avoid resource leak
107109
zmq_socket.close()
108-
tuner.terminate_episode()
109110
if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path)
110111

111112
return final_output

ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, server_url: str):
2626
self.server_url = server_url
2727
self.client_uuid = str(uuid.uuid4())
2828
self.previous_warning_time = 0
29+
self.record_episode_expire_time = {}
2930

3031

3132
def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple[str, OpenaiBaseUrlAndApiKey]:
@@ -48,6 +49,7 @@ def begin_episode(self, allow_discard_timeout=60, episode_type="train") -> Tuple
4849
resp.raise_for_status()
4950
data = ClaimEpisodeResponse.model_validate(resp.json())
5051
episode_uuid = data.episode_uuid
52+
self.record_episode_expire_time[episode_uuid] = time.time() + allow_discard_timeout
5153

5254
if data.success:
5355
episode_uuid = data.episode_uuid
@@ -82,6 +84,11 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
8284
logger.error("No episode to end.")
8385
return
8486

87+
remain_time = self.record_episode_expire_time.get(episode_uuid, 0) - time.time()
88+
if remain_time < 0:
89+
logger.warning(f"Episode {episode_uuid} has expired (expired {remain_time} seconds ago). Please use a larger `allow_discard_timeout` when `begin_episode`. Skipping end_episode.")
90+
return
91+
8592
try:
8693
task_id = task.task_id
8794
workflow_output.metadata["task_id"] = task_id
@@ -131,7 +138,7 @@ def abort_episode(self, episode_uuid: str):
131138
data = EndEpisodeResponse.model_validate(resp.json())
132139

133140
if data.success:
134-
logger.info(f"Ended episode {episode_uuid}")
141+
logger.info(f"Aborted episode {episode_uuid}")
135142
else:
136143
logger.error(f"Failed to end episode {episode_uuid}")
137144

ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,20 @@ def register_enable_swarm_mode_routes(
6161
# ------------------------------------------------------------------------------------------------
6262

6363
async def find_claimed_episodes_that_need_to_be_unclaimed() -> List[str]:
64-
result = []
64+
to_unclaim_episodes = []
6565
current_time = time.time()
6666

6767
for k, v in shared_mem_dict.items():
6868
if is_key_epsisode_status(k):
6969
es:EpisodeStatus = v
7070
if es.episode_status == "claimed":
7171
if (current_time - es.latest_activity_timestamp) > es.allow_discard_timeout:
72-
result.append(es.episode_uuid)
72+
to_unclaim_episodes.append(es.episode_uuid)
7373

74-
for episode_uuid in result:
74+
for episode_uuid in to_unclaim_episodes:
7575
await _revert_episode_to_unclaimed(episode_uuid, shared_mem_dict, shared_mem_dict_lock)
7676

77-
return result
77+
return to_unclaim_episodes
7878

7979
def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must async
8080
# send message to context tracker
@@ -110,6 +110,8 @@ def _context_tracker_reset_blocking(episode_uuid, shared_mem_dict): # must asyn
110110
async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock):
111111
# check status again, because other thread may have changed it
112112
if shared_mem_dict[ep_key(episode_uuid)].episode_status != "claimed":
113+
if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass
114+
else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
113115
return
114116

115117
# reset context tracker
@@ -126,17 +128,15 @@ async def _revert_episode_to_unclaimed(episode_uuid: str, shared_mem_dict, share
126128
es.allow_discard_timeout = -1
127129
with shared_mem_dict_lock:
128130
shared_mem_dict[ep_key(episode_uuid)] = es
129-
if episode_uuid in shared_mem_dict['unclaimed_episodes']:
130-
pass
131-
else:
132-
shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
131+
if episode_uuid in shared_mem_dict['unclaimed_episodes']: pass
132+
else: shared_mem_dict['unclaimed_episodes'] += [episode_uuid]
133133

134134
def _delete_episode_record(episode_uuid: str, shared_mem_dict, shared_mem_dict_lock):
135135

136136
with shared_mem_dict_lock:
137137
# remove episode record
138138
if ep_key(episode_uuid) in shared_mem_dict:
139-
del shared_mem_dict[ep_key(episode_uuid)]
139+
del shared_mem_dict[ep_key(episode_uuid)] # RM--
140140
logger.info(f"Deleted episode record for {episode_uuid}.")
141141
# remove from unclaimed list if present
142142
if episode_uuid in shared_mem_dict['unclaimed_episodes']:
@@ -499,7 +499,17 @@ async def end_episode(req: EndEpisodeRequest):
499499

500500
# send workflow_output to zmq
501501
assert 'episodes' in shared_mem_dict
502-
episode_type = shared_mem_dict[ep_key(episode_uuid)].episode_type
502+
ep_stat = shared_mem_dict[ep_key(episode_uuid)]
503+
episode_type = ep_stat.episode_type
504+
episode_status = ep_stat.episode_status
505+
client_uuid_recorded = ep_stat.client_uuid
506+
if client_uuid_recorded != client_uuid:
507+
logger.error(f"[server] Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.")
508+
raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} is claimed by different client: {client_uuid_recorded}, but got {client_uuid}.")
509+
510+
if episode_status != "claimed":
511+
logger.error(f"[server] Episode {episode_uuid} is not in claimed status.")
512+
raise HTTPException(status_code=400, detail=f"Episode {episode_uuid} is not in claimed status, maybe you take too long to submit.")
503513

504514
if episode_type == "train":
505515
# _register_final_episode_output_blocking(episode_uuid, workflow_output, shared_mem_dict, shared_mem_dict_lock) # must async

ajet/utils/thread_executors.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,42 @@
1+
from concurrent.futures import ThreadPoolExecutor
12
from ajet.utils.sington import singleton
2-
import concurrent.futures
3-
3+
import threading
44

55

66
@singleton
77
class SharedInterchangeThreadExecutor:
88
def __init__(self, max_workers=64):
9-
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
9+
self.executor = ThreadPoolExecutor(max_workers=max_workers)
1010

11-
def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor:
11+
def get_shared_executor(self) -> ThreadPoolExecutor:
1212
return self.executor
1313

1414

1515

1616
@singleton
1717
class SharedInferenceTrackerThreadExecutor:
1818
def __init__(self, max_workers=64):
19-
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
19+
self.executor = ThreadPoolExecutor(max_workers=max_workers)
2020

21-
def get_shared_executor(self) -> concurrent.futures.ThreadPoolExecutor:
21+
def get_shared_executor(self) -> ThreadPoolExecutor:
2222
return self.executor
23+
24+
25+
class BoundedThreadPoolExecutor:
    """Thread pool whose number of in-flight (queued + running) tasks is bounded.

    Unlike a bare ``ThreadPoolExecutor`` (whose internal work queue is
    unbounded), ``submit`` blocks the calling thread once ``max_queue_size``
    tasks are in flight, providing back-pressure so fast producers cannot
    queue work (and retain its arguments) without limit.
    """

    def __init__(self, max_workers, max_queue_size=100):
        # max_workers: threads actually executing tasks.
        # max_queue_size: cap on tasks submitted but not yet finished;
        #                 submit() blocks when this many are in flight.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = threading.Semaphore(max_queue_size)

    def submit(self, fn, *args, **kwargs):
        """Submit ``fn(*args, **kwargs)``; blocks while the bound is reached.

        Returns the ``concurrent.futures.Future`` for the task.
        """
        # Take a slot first; this is what makes the pool "bounded".
        self.semaphore.acquire()

        def wrapped_fn(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            finally:
                # Free the slot whether the task succeeds or raises.
                self.semaphore.release()

        try:
            return self.executor.submit(wrapped_fn, *args, **kwargs)
        except BaseException:
            # BUGFIX: if submission itself fails (e.g. RuntimeError after
            # shutdown()), the acquired slot would otherwise leak and
            # permanently shrink the bound, eventually deadlocking callers.
            self.semaphore.release()
            raise

    def shutdown(self, wait=True):
        """Shut down the underlying executor (see ThreadPoolExecutor.shutdown)."""
        self.executor.shutdown(wait=wait)

tutorial/example_academic_trans/trans.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
4949
grader = TranslationQualityGrader(
5050
model=OpenAIChatModel(base_url=grader_base_url, api_key=grader_api_key, model="qwen3-max-2026-01-23")
5151
)
52-
grader_score = asyncio.run(grader.aevaluate(original_text=abstract, translation=final_translation))
52+
grader_score = asyncio.run(asyncio.wait_for(grader.aevaluate(original_text=abstract, translation=final_translation), timeout=120))
5353
raw_reward = grader_score.score
5454
print(f"Grader Score: {grader_score.score}, Reason: {grader_score.reason}, Metadata: {grader_score.metadata}")
5555
return WorkflowOutput(reward=raw_reward, metadata={
@@ -111,7 +111,8 @@ def detect_hard_proper_nouns(messages, base_url, api_key, abstract, rough_transl
111111
response = client.chat.completions.create(
112112
model="qwen3-max-2026-01-23",
113113
messages=messages,
114-
extra_body={"enable_thinking":True}
114+
timeout=60,
115+
# extra_body={"enable_thinking":True}
115116
)
116117
fix_nouns = response.choices[0].message.content
117118
messages += [

tutorial/example_academic_trans/trans_roll.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ajet.tuner_lib.weight_tuner.experimental.as_swarm_client import SwarmClient
99
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
1010
from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
11-
from ajet import WorkflowOutput
11+
from ajet.utils.thread_executors import BoundedThreadPoolExecutor
1212
from ajet.schema.task import Task
1313
from ajet.task_reader import RouterTaskReader
1414
from ajet.utils.retry import retry_with_backoff
@@ -56,7 +56,7 @@ def main():
5656
model=REMOTE_TRAIN_MODEL_01,
5757
batch_size=REMOTE_BATCH_SIZE,
5858
grpo_n=LOCAL_GRPO_N,
59-
)
59+
),
6060
)
6161

6262
def rollout(task):
@@ -80,20 +80,17 @@ def rollout(task):
8080
logger.exception("Exception during rollout group", e)
8181

8282
task_batch = []
83+
executor = BoundedThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL, max_queue_size=LOCAL_MAX_PARALLEL*2)
8384
for i, task in enumerate(dataset.generate_training_tasks()):
8485
task_batch += [task]
8586

8687
if len(task_batch) == REMOTE_BATCH_SIZE:
8788
print('*********** beginning a new batch of tasks... ***********')
88-
with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor:
89-
for task in task_batch:
90-
executor.submit(rollout, task)
91-
executor.shutdown(wait=True)
89+
for task in task_batch:
90+
executor.submit(rollout, task)
9291
task_batch = []
93-
print('*********** tasks completed, wait a minute... ***********')
94-
time.sleep(60)
95-
9692

93+
executor.shutdown(wait=True)
9794
return None
9895

9996

0 commit comments

Comments
 (0)