Skip to content

Commit 509fc59

Browse files
committed
improve interaction
1 parent 3581363 commit 509fc59

File tree

16 files changed

+628
-307
lines changed

16 files changed

+628
-307
lines changed
File renamed without changes.
File renamed without changes.

ajet/backbone/trainer_verl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def _update_interchange_server_status_flag(self, status: str):
460460
if self.config.ajet.enable_experimental_interchange_server:
461461
if self.config.ajet.enable_swarm_mode:
462462
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
463-
http_change_engine_status(self.config, status)
463+
http_change_engine_status(self.config, status, global_step=self.global_steps)
464464

465465
# #######################################
466466
# training loop

ajet/launcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def main():
201201
exe_exp_base,
202202
exp_name,
203203
exp_config,
204-
) = prepare_experiment_config(yaml_path, exp_dir, args.backbone)
204+
) = prepare_experiment_config(yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server))
205205

206206
# setup environment variables
207207
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)

ajet/task_rollout/native_parallel_worker.py

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,10 @@ def step_status_printer(self, observation_window):
7878
print_buf += [f"[finished]:{count} threads"]
7979
print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf))
8080

81-
if DEBUG:
82-
self._write_swarm_rollout_dynamic_log(observation_window)
83-
8481

8582
def _write_swarm_rollout_dynamic_log(self, observation_window):
86-
fp = "./swarm_rollout.dynamic.log"
83+
base_exp_dir = self.config.ajet.experiment_dir # {exp-dir}/{experiment_name}
84+
fp = f"{base_exp_dir}/swarm_rollout.dynamic.log"
8785
string_buffer = ""
8886
for info in observation_window["info"]:
8987
string_buffer += f"{info}\n"
@@ -189,25 +187,6 @@ def rollout_swarm( # noqa: C901
189187
completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {}
190188
executor_lock = threading.Lock()
191189

192-
# submit initial tasks
193-
dummy_task = Task(main_query="dummy task")
194-
for task_batch_index in range(n_task):
195-
for task_rollout_index in range(rollout_n):
196-
task_thread_index = task_batch_index * rollout_n + task_rollout_index
197-
observation_window["info"][task_thread_index] = f"\n\n\n\n[thread {task_thread_index} submit]\n"
198-
future = executor.submit(
199-
self.rollout_env_worker_loop,
200-
task=dummy_task,
201-
task_tag="",
202-
mode=mode,
203-
task_batch_index=task_batch_index,
204-
task_thread_index=task_thread_index,
205-
observation_window=observation_window,
206-
completed_task_id_map_ct=completed_task_id_map_ct,
207-
executor_lock=executor_lock,
208-
)
209-
futures.append(future)
210-
211190
def count_tasks(completed_task_id_map_ct):
212191
total_completed_episodes = 0
213192
total_completed_tasks = 0
@@ -239,7 +218,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool:
239218
counts = count_tasks(completed_task_id_map_ct)
240219
total_completed_episodes = counts["total_completed_episodes"]
241220
total_completed_tasks = counts["total_completed_tasks"]
242-
if total_completed_episodes > self.config.ajet.swarm_mode_sample_collection_max_cached_episodes // 2:
221+
if total_completed_episodes > (self.config.ajet.swarm_mode_sample_collection_max_cached_episodes // 5 * 4):
243222
logger.warning(
244223
f"Total cached episodes [{total_completed_episodes}] is going to exceed the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}], "
245224
f"but we are still not able to meet the stop condition (current finished tasks [{total_completed_tasks}], target tasks [{n_batch_task}]), this may cause memory issues. "
@@ -261,7 +240,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
261240
total_completed_episodes = counts["total_completed_episodes"]
262241
total_completed_tasks = counts["total_completed_tasks"]
263242
total_completed_non_dummy_tasks = counts["total_completed_non_dummy_tasks"]
264-
if total_completed_episodes > self.config.ajet.swarm_mode_sample_collection_max_cached_episodes // 2:
243+
if total_completed_episodes > (self.config.ajet.swarm_mode_sample_collection_max_cached_episodes // 5 * 4):
265244
logger.warning(
266245
f"Total cached episodes [{total_completed_episodes}] is going to exceed the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}], "
267246
f"but we are still not able to meet the stop condition (current finished tasks [{total_completed_non_dummy_tasks}], target tasks [{n_batch_task}]), this may cause memory issues. "
@@ -291,6 +270,9 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool:
291270
logger.error(f"Invalid swarm_mode_sample_collection_method: {self.config.ajet.swarm_mode_sample_collection_method}, fallback to default method: rollout_until_finish_enough_tasks")
292271
stop_condition = enough_finished_task_stop_condition
293272

273+
def is_already_soft_stopped():
274+
return all(observation_window["stop"])
275+
294276
# communicate with interchange server to stop new episode, and let threads finish current episode, then collect results and shutdown executor
295277
def stop_all_threads_soft():
296278
for k in range(len(observation_window["stop"])): observation_window["stop"][k] = True
@@ -303,6 +285,34 @@ def stop_all_threads_hard():
303285
http_change_engine_status(self.config, "ENGINE.WEIGHT_SYNCING")
304286
return
305287

288+
# pass a stop condition callback function to each thread, so that threads can check the stop condition whenever it finishes a cycle, this is faster than polling
289+
def stop_condition_callback(completed_task_id_map_ct):
290+
if stop_condition(completed_task_id_map_ct):
291+
if not is_already_soft_stopped():
292+
stop_all_threads_soft()
293+
return True
294+
return False
295+
296+
# submit initial tasks
297+
dummy_task = Task(main_query="dummy task")
298+
for task_batch_index in range(n_task):
299+
for task_rollout_index in range(rollout_n):
300+
task_thread_index = task_batch_index * rollout_n + task_rollout_index
301+
observation_window["info"][task_thread_index] = f"\n\n\n\n[thread {task_thread_index} submit]\n"
302+
future = executor.submit(
303+
self.rollout_env_worker_loop,
304+
task=dummy_task,
305+
task_tag="",
306+
mode=mode,
307+
task_batch_index=task_batch_index,
308+
task_thread_index=task_thread_index,
309+
observation_window=observation_window,
310+
completed_task_id_map_ct=completed_task_id_map_ct,
311+
executor_lock=executor_lock,
312+
stop_condition_callback=stop_condition_callback,
313+
)
314+
futures.append(future)
315+
306316
def update_rollout_result_array_preview(observation_window, completed_task_id_map_ct: Dict[str, List[BaseContextTracker]]):
307317
buffer = ""
308318
completed_tasks_details = {}
@@ -334,17 +344,19 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
334344
completed_tasks_details=completed_tasks_details,
335345
)
336346
http_update_rollout_pool_information(self.config, pool_info)
337-
338347
return
339348

340349
# loop and wait until stop condition is met, then stop threads and collect results
350+
CHECK_STATUS_INTERVAL = 4 # seconds
351+
PRINT_STATUS_INTERVAL = 12 # seconds
341352
cnt = 0
342353
while True:
343354
cnt += 1
344-
time.sleep(2)
345-
if (cnt % 5 == 0):
355+
time.sleep(CHECK_STATUS_INTERVAL)
356+
if (cnt % ( PRINT_STATUS_INTERVAL//CHECK_STATUS_INTERVAL ) == 0):
346357
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
347358
self.step_status_printer(observation_window)
359+
self._write_swarm_rollout_dynamic_log(observation_window)
348360
meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct)
349361
if meet_stop_condition_after_new_results:
350362
print("Sending soft stop signal to all threads...")
@@ -377,9 +389,8 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
377389
tracker.current_batch_reward = float(task_scalar_reward)
378390

379391
# for debugging
380-
if DEBUG:
381-
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
382-
self._write_swarm_rollout_dynamic_log(observation_window)
392+
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
393+
self._write_swarm_rollout_dynamic_log(observation_window)
383394

384395
return tracker_array
385396

@@ -428,9 +439,9 @@ def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> Li
428439
try:
429440
sample_arr = tracker.group_tokenize()
430441
except Exception as e:
442+
logger.bind(exception=True).exception("Error during tracker.group_tokenize()")
431443
raise e
432444
finally:
433-
logger.bind(exception=True).exception("Error during tracker.tokenize()") # for debugging
434445
tracker.generate_log(global_step=self.current_global_steps)
435446
if os.environ.get("BEST_LOGGER_PATH", None) and os.environ.get(
436447
"AJET_DEBUG", None

ajet/task_rollout/single_worker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def rollout_env_worker_loop(
159159
observation_window: dict,
160160
completed_task_id_map_ct: Dict[str, List[BaseContextTracker]],
161161
executor_lock: threading.Lock,
162+
stop_condition_callback=None,
162163
**kwargs,
163164
):
164165

@@ -198,6 +199,10 @@ def rollout_env_worker_loop(
198199

199200
cnt += 1
200201

202+
if stop_condition_callback is not None and stop_condition_callback(completed_task_id_map_ct):
203+
observation_window["info"][task_thread_index] += f"[thread {task_thread_index} observe stop_condition_callback true, returning]\n"
204+
return
205+
201206
if observation_window["stop"][task_thread_index]:
202207
observation_window["info"][task_thread_index] += f"[thread {task_thread_index} observe stop, returning]\n"
203208
return

0 commit comments

Comments
 (0)