@@ -86,6 +86,14 @@ def __init__(
8686 # Progress bars (initialized in run())
8787 self .rollout_pbar : Optional [async_tqdm ] = None
8888 self .eval_pbar : Optional [async_tqdm ] = None
89+
90+ # Track active rollouts: {row_index: set of run_indices currently in progress}
91+ self .active_rollouts : Dict [int , set ] = defaultdict (set )
92+ self .active_rollouts_lock = asyncio .Lock ()
93+
94+ # Track active evaluations
95+ self .active_evals : int = 0
96+ self .active_evals_lock = asyncio .Lock ()
8997
9098 async def schedule_dataset (
9199 self ,
@@ -140,51 +148,64 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
140148 run_id = rows_to_eval [0 ].execution_metadata .run_id if isinstance (rows_to_eval , list ) else rows_to_eval .execution_metadata .run_id
141149 eval_res = None
142150
151+ # Track active eval
152+ async with self .active_evals_lock :
153+ self .active_evals += 1
154+ if self .eval_pbar :
155+ self .eval_pbar .set_postfix_str (f"active={ self .active_evals } " )
156+
143157 start_time = time .perf_counter ()
144158
145- async with self .eval_sem :
146- async with rollout_logging_context (
147- rollout_id or "" ,
148- experiment_id = experiment_id ,
149- run_id = run_id ,
150- ):
151- if isinstance (rows_to_eval , list ):
152- eval_res = await execute_pytest_with_exception_handling (
153- test_func = self .eval_executor ,
154- evaluation_test_kwargs = self .evaluation_test_kwargs ,
155- processed_dataset = rows_to_eval ,
156- )
159+ try :
160+ async with self .eval_sem :
161+ async with rollout_logging_context (
162+ rollout_id or "" ,
163+ experiment_id = experiment_id ,
164+ run_id = run_id ,
165+ ):
166+ if isinstance (rows_to_eval , list ):
167+ eval_res = await execute_pytest_with_exception_handling (
168+ test_func = self .eval_executor ,
169+ evaluation_test_kwargs = self .evaluation_test_kwargs ,
170+ processed_dataset = rows_to_eval ,
171+ )
172+ else :
173+ eval_res = await execute_pytest_with_exception_handling (
174+ test_func = self .eval_executor ,
175+ evaluation_test_kwargs = self .evaluation_test_kwargs ,
176+ processed_row = rows_to_eval ,
177+ )
178+ eval_duration = time .perf_counter () - start_time
179+ # push result to the output buffer
180+ if self .output_buffer :
181+ if isinstance (eval_res , list ):
182+ for row in eval_res :
183+ self ._post_process_result (row )
184+ await self .output_buffer .add_result (row )
157185 else :
158- eval_res = await execute_pytest_with_exception_handling (
159- test_func = self .eval_executor ,
160- evaluation_test_kwargs = self .evaluation_test_kwargs ,
161- processed_row = rows_to_eval ,
162- )
163- eval_duration = time .perf_counter () - start_time
164- # push result to the output buffer
165- if self .output_buffer :
186+ self ._post_process_result (eval_res )
187+ await self .output_buffer .add_result (eval_res )
188+
166189 if isinstance (eval_res , list ):
167190 for row in eval_res :
168- self ._post_process_result (row )
169- await self .output_buffer .add_result (row )
191+ row .execution_metadata .eval_duration_seconds = eval_duration
192+ self .results .append (row )
193+ # Update eval progress bar (groupwise: 1 eval for the group)
194+ if self .eval_pbar :
195+ self .eval_pbar .update (1 )
170196 else :
171- self ._post_process_result (eval_res )
172- await self .output_buffer .add_result (eval_res )
173-
174- if isinstance (eval_res , list ):
175- for row in eval_res :
176- row .execution_metadata .eval_duration_seconds = eval_duration
177- self .results .append (row )
178- # Update eval progress bar (groupwise: 1 eval for the group)
179- if self .eval_pbar :
180- self .eval_pbar .update (1 )
181- else :
182- eval_res .execution_metadata .eval_duration_seconds = eval_duration
183- self .results .append (eval_res )
184- # Update eval progress bar (pointwise: 1 eval per row)
185- if self .eval_pbar :
186- self .eval_pbar .update (1 )
187- return eval_res
197+ eval_res .execution_metadata .eval_duration_seconds = eval_duration
198+ self .results .append (eval_res )
199+ # Update eval progress bar (pointwise: 1 eval per row)
200+ if self .eval_pbar :
201+ self .eval_pbar .update (1 )
202+ return eval_res
203+ finally :
204+ # Decrement active eval counter
205+ async with self .active_evals_lock :
206+ self .active_evals -= 1
207+ if self .eval_pbar :
208+ self .eval_pbar .set_postfix_str (f"active={ self .active_evals } " )
188209
189210 # 1. Prepare Config & Row for this micro-batch
190211 current_batch_rows = []
@@ -223,23 +244,36 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
223244 batch_results : List [EvaluationRow ] = []
224245 if current_batch_rows :
225246 for idx , row in current_batch_rows :
247+ # Track this rollout as active
248+ async with self .active_rollouts_lock :
249+ self .active_rollouts [task .row_index ].add (idx )
250+ await self ._update_rollout_pbar_postfix ()
251+
226252 start_time = time .perf_counter ()
227- async for result_row in rollout_processor_with_retry (
228- self .rollout_processor , [row ], task .config , idx , disable_tqdm = True
229- ):
230- rollout_duration = time .perf_counter () - start_time
231- result_row .execution_metadata .rollout_duration_seconds = rollout_duration
232- batch_results .append (result_row )
233-
234- # Update rollout progress bar
235- if self .rollout_pbar :
236- self .rollout_pbar .update (1 )
237-
238- # in pointwise, we start evaluation immediately
239- if self .mode == "pointwise" :
240- t = asyncio .create_task (_run_eval (result_row ))
241- self .background_tasks .add (t )
242- t .add_done_callback (self .background_tasks .discard )
253+ try :
254+ async for result_row in rollout_processor_with_retry (
255+ self .rollout_processor , [row ], task .config , idx , disable_tqdm = True
256+ ):
257+ rollout_duration = time .perf_counter () - start_time
258+ result_row .execution_metadata .rollout_duration_seconds = rollout_duration
259+ batch_results .append (result_row )
260+
261+ # Update rollout progress bar
262+ if self .rollout_pbar :
263+ self .rollout_pbar .update (1 )
264+
265+ # in pointwise, we start evaluation immediately
266+ if self .mode == "pointwise" :
267+ t = asyncio .create_task (_run_eval (result_row ))
268+ self .background_tasks .add (t )
269+ t .add_done_callback (self .background_tasks .discard )
270+ finally :
271+ # Remove from active tracking
272+ async with self .active_rollouts_lock :
273+ self .active_rollouts [task .row_index ].discard (idx )
274+ if not self .active_rollouts [task .row_index ]:
275+ del self .active_rollouts [task .row_index ]
276+ await self ._update_rollout_pbar_postfix ()
243277
244278 # 3. Evaluate and Collect History
245279 current_batch_history_updates = []
@@ -283,6 +317,34 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
283317 )
284318 self .queue .put_nowait (new_task )
285319
320+ def _format_active_rollouts (self ) -> str :
321+ """Format active rollouts for display in progress bar."""
322+ if not self .active_rollouts :
323+ return ""
324+
325+ # Show active rows and their run indices
326+ parts = []
327+ for row_idx in sorted (self .active_rollouts .keys ())[:5 ]: # Limit to 5 rows to keep it readable
328+ runs = sorted (self .active_rollouts [row_idx ])
329+ if runs :
330+ runs_str = "," .join (str (r ) for r in runs [:3 ]) # Show up to 3 run indices
331+ if len (runs ) > 3 :
332+ runs_str += f"+{ len (runs )- 3 } "
333+ parts .append (f"r{ row_idx } :[{ runs_str } ]" )
334+
335+ if len (self .active_rollouts ) > 5 :
336+ parts .append (f"+{ len (self .active_rollouts )- 5 } more" )
337+
338+ return " | " .join (parts )
339+
340+ async def _update_rollout_pbar_postfix (self ):
341+ """Update the rollout progress bar postfix with active tasks info."""
342+ if self .rollout_pbar :
343+ active_count = sum (len (runs ) for runs in self .active_rollouts .values ())
344+ self .rollout_pbar .set_postfix_str (
345+ f"active={ active_count } { self ._format_active_rollouts ()} "
346+ )
347+
286348 def _post_process_result (self , res : EvaluationRow ):
287349 """
288350 Process evaluation result: update cost metrics, status, and log.
0 commit comments