|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | import time |
5 | | -from concurrent.futures import Future, ThreadPoolExecutor |
| 5 | +from concurrent.futures import Future, ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED |
6 | 6 | from typing import Dict, List, Literal |
7 | 7 | from urllib.parse import quote |
8 | 8 |
|
@@ -59,6 +59,9 @@ def step_status_printer(self, observation_window): |
59 | 59 | if start == -1: |
60 | 60 | print_buf += [f"[finished]:{count} threads"] |
61 | 61 | print(f"Rollout progress ({token_gen_per_sec_str}): " + " // ".join(print_buf)) |
| 62 | + if "info" in observation_window: |
| 63 | + print_buf2 = "\t".join(observation_window["info"]) |
| 64 | + print(print_buf2) |
62 | 65 |
|
63 | 66 | def rollout_static( |
64 | 67 | self, |
@@ -139,7 +142,9 @@ def rollout( |
139 | 142 | epoch: str, |
140 | 143 | ) -> List[BaseContextTracker]: |
141 | 144 | """Delegate to dynamic rollout when oversampling is enabled.""" |
142 | | - if ( |
| 145 | + if self.config.ajet.enable_tinkerscript_mode: |
| 146 | + return self.rollout_swarm(tasks, mode, epoch) |
| 147 | + elif ( |
143 | 148 | mode == "sample" |
144 | 149 | and (self.rollout_n != 1) |
145 | 150 | and self.config.ajet.rollout.enable_oversample |
@@ -459,6 +464,144 @@ def rollout_dynamic( # noqa: C901 |
459 | 464 | return tracker_array |
460 | 465 |
|
461 | 466 |
|
| 467 | + |
| 468 | + def rollout_swarm( # noqa: C901 |
| 469 | + self, |
| 470 | + tasks: List[Task], |
| 471 | + mode: Literal["sample", "validate"], |
| 472 | + epoch: str, |
| 473 | + allow_sample_num_change=True, |
| 474 | + allow_force_stop=True, |
| 475 | + ) -> List[BaseContextTracker]: |
| 476 | + """ |
| 477 | + Build a pool of threads to run context trackers in parallel, |
| 478 | + each thread re-spawn after complete, until reaching conditions to stop. |
| 479 | + """ |
| 480 | + |
| 481 | + tracker_array: List[BaseContextTracker] = [] |
| 482 | + assert mode != "validate" |
| 483 | + rollout_n = self.rollout_n |
| 484 | + n_task = len(tasks) |
| 485 | + self.current_token_count_time = time.time() |
| 486 | + |
| 487 | + # initialize observation window |
| 488 | + observation_window: Dict[str, List[int | bool | str]] = { |
| 489 | + "info": ["" for _ in range(n_task * rollout_n)], |
| 490 | + "step": [0 for _ in range(n_task * rollout_n)], |
| 491 | + "stop": [False for _ in range(n_task * rollout_n)], |
| 492 | + "token": [0 for _ in range(n_task * rollout_n)], |
| 493 | + } |
| 494 | + executor = ThreadPoolExecutor(max_workers=self.max_parallel) |
| 495 | + futures: List[Future] = [] |
| 496 | + completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {} |
| 497 | + |
| 498 | + # submit initial tasks |
| 499 | + dummy_task = Task(main_query="dummy task") |
| 500 | + for task_batch_index in range(n_task): |
| 501 | + for task_rollout_index in range(rollout_n): |
| 502 | + task_thread_index = task_batch_index * rollout_n + task_rollout_index |
| 503 | + future = executor.submit( |
| 504 | + self.rollout_env_worker, |
| 505 | + task=dummy_task, |
| 506 | + task_tag="", |
| 507 | + mode=mode, |
| 508 | + task_batch_index=task_batch_index, |
| 509 | + task_thread_index=task_thread_index, |
| 510 | + observation_window=observation_window, |
| 511 | + ) |
| 512 | + observation_window["info"][task_thread_index] = "1" |
| 513 | + futures.append(future) |
| 514 | + |
| 515 | + def enough_sample_stop_condition(completed_task_id_map_ct) -> bool: |
| 516 | + n = 0 |
| 517 | + for ct_list in completed_task_id_map_ct.values(): |
| 518 | + n += len(ct_list) |
| 519 | + return (n >= n_task * rollout_n) |
| 520 | + |
| 521 | + def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: |
| 522 | + n_finish_roll_task = 0 |
| 523 | + for ct_list in completed_task_id_map_ct.values(): |
| 524 | + if len(ct_list) >= rollout_n: |
| 525 | + n_finish_roll_task += 1 |
| 526 | + return (n_finish_roll_task >= n_task) |
| 527 | + |
| 528 | + def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: |
| 529 | + n_finish_roll_task = 0 |
| 530 | + for ct_list in completed_task_id_map_ct.values(): |
| 531 | + task_cmd_reward_array = [ |
| 532 | + tracker.reward_structure.performance_reward for tracker in ct_list |
| 533 | + ] |
| 534 | + if (len(ct_list) >= rollout_n): |
| 535 | + all_equal = all(x == task_cmd_reward_array[0] for x in task_cmd_reward_array) |
| 536 | + if all_equal: continue |
| 537 | + n_finish_roll_task += 1 |
| 538 | + return (n_finish_roll_task >= n_task) |
| 539 | + |
| 540 | + stop_condition = enough_sample_stop_condition |
| 541 | + |
| 542 | + def force_stop_all_threads(): |
| 543 | + for k in range(len(observation_window["stop"])): |
| 544 | + observation_window["stop"][k] = True |
| 545 | + return |
| 546 | + |
| 547 | + tic = time.time() |
| 548 | + while True: |
| 549 | + # wait for a completed task |
| 550 | + done_arr, pending_arr = wait(futures, timeout=10, return_when=FIRST_COMPLETED) |
| 551 | + print(f"Done tasks: {len(done_arr)}, Pending tasks: {len(pending_arr)}") |
| 552 | + toc = time.time() |
| 553 | + if (toc - tic) > 8: |
| 554 | + tic = toc |
| 555 | + self.step_status_printer(observation_window) |
| 556 | + # get result |
| 557 | + for future in done_arr: |
| 558 | + ct: BaseContextTracker = future.result() |
| 559 | + if ct.task_id not in completed_task_id_map_ct: |
| 560 | + completed_task_id_map_ct[ct.task_id] = [ct] |
| 561 | + else: |
| 562 | + completed_task_id_map_ct[ct.task_id] += [ct] |
| 563 | + # if meet stop condition |
| 564 | + meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) |
| 565 | + if meet_stop_condition_after_new_results: |
| 566 | + force_stop_all_threads() |
| 567 | + break |
| 568 | + else: |
| 569 | + # re-spawn new tasks for done futures |
| 570 | + for task_batch_index in range(n_task): |
| 571 | + for task_rollout_index in range(rollout_n): |
| 572 | + task_thread_index = task_batch_index * rollout_n + task_rollout_index |
| 573 | + has_done = (futures[task_thread_index] in done_arr) |
| 574 | + |
| 575 | + observation_window["info"][task_thread_index] = str(int(observation_window["info"][task_thread_index]) + 1) |
| 576 | + observation_window["stop"][task_thread_index] = False |
| 577 | + observation_window["step"][task_thread_index] = 0 |
| 578 | + |
| 579 | + if has_done: |
| 580 | + print(f"Re-spawning thread {task_thread_index}...") |
| 581 | + future = executor.submit( |
| 582 | + self.rollout_env_worker, |
| 583 | + task=dummy_task, |
| 584 | + task_tag="", |
| 585 | + mode=mode, |
| 586 | + task_batch_index=task_batch_index, |
| 587 | + task_thread_index=task_thread_index, |
| 588 | + observation_window=observation_window, |
| 589 | + ) |
| 590 | + futures[task_thread_index] = future |
| 591 | + |
| 592 | + # wait for all threads to complete |
| 593 | + print('Finalizing all threads...') |
| 594 | + wait(futures, return_when=ALL_COMPLETED) |
| 595 | + |
| 596 | + # build tracker_array |
| 597 | + print('Collecting results...') |
| 598 | + for ct_list in completed_task_id_map_ct.values(): |
| 599 | + tracker_array.extend(ct_list) |
| 600 | + |
| 601 | + # return all trackers |
| 602 | + return tracker_array |
| 603 | + |
| 604 | + |
462 | 605 | class VerlRolloutManager(DynamicRolloutManager): |
463 | 606 | """High-level manager orchestrating rollouts and batch conversion.""" |
464 | 607 |
|
|
0 commit comments