Skip to content

Commit bfbe923

Browse files
committed
update rollout result when stop_condition_callback
1 parent df222cd commit bfbe923

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

ajet/task_rollout/native_parallel_worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def rollout_swarm( # noqa: C901
187187
completed_task_id_map_ct: Dict[str, List[BaseContextTracker]] = {}
188188
executor_lock = threading.Lock()
189189

190+
# count tasks to see whether we have reach the finish line for next weight update
190191
def count_tasks(completed_task_id_map_ct):
191192
total_completed_episodes = 0
192193
total_completed_tasks = 0
@@ -290,7 +291,9 @@ def stop_condition_callback(completed_task_id_map_ct):
290291
if stop_condition(completed_task_id_map_ct):
291292
if not is_already_soft_stopped():
292293
stop_all_threads_soft()
294+
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
293295
return True
296+
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
294297
return False
295298

296299
# submit initial tasks
@@ -388,7 +391,6 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma
388391
tracker.current_batch_success_rate = float(task_success_rate)
389392
tracker.current_batch_reward = float(task_scalar_reward)
390393

391-
# for debugging
392394
update_rollout_result_array_preview(observation_window, completed_task_id_map_ct)
393395
self._write_swarm_rollout_dynamic_log(observation_window)
394396

0 commit comments

Comments
 (0)