Skip to content

Commit 9ceb154

Browse files
committed
tqdm for scheduler
1 parent fb9c58c commit 9ceb154

1 file changed

Lines changed: 117 additions & 55 deletions

File tree

eval_protocol/pytest/priority_scheduler.py

Lines changed: 117 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ def __init__(
8686
# Progress bars (initialized in run())
8787
self.rollout_pbar: Optional[async_tqdm] = None
8888
self.eval_pbar: Optional[async_tqdm] = None
89+
90+
# Track active rollouts: {row_index: set of run_indices currently in progress}
91+
self.active_rollouts: Dict[int, set] = defaultdict(set)
92+
self.active_rollouts_lock = asyncio.Lock()
93+
94+
# Track active evaluations
95+
self.active_evals: int = 0
96+
self.active_evals_lock = asyncio.Lock()
8997

9098
async def schedule_dataset(
9199
self,
@@ -140,51 +148,64 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
140148
run_id = rows_to_eval[0].execution_metadata.run_id if isinstance(rows_to_eval, list) else rows_to_eval.execution_metadata.run_id
141149
eval_res = None
142150

151+
# Track active eval
152+
async with self.active_evals_lock:
153+
self.active_evals += 1
154+
if self.eval_pbar:
155+
self.eval_pbar.set_postfix_str(f"active={self.active_evals}")
156+
143157
start_time = time.perf_counter()
144158

145-
async with self.eval_sem:
146-
async with rollout_logging_context(
147-
rollout_id or "",
148-
experiment_id=experiment_id,
149-
run_id=run_id,
150-
):
151-
if isinstance(rows_to_eval, list):
152-
eval_res = await execute_pytest_with_exception_handling(
153-
test_func=self.eval_executor,
154-
evaluation_test_kwargs=self.evaluation_test_kwargs,
155-
processed_dataset=rows_to_eval,
156-
)
159+
try:
160+
async with self.eval_sem:
161+
async with rollout_logging_context(
162+
rollout_id or "",
163+
experiment_id=experiment_id,
164+
run_id=run_id,
165+
):
166+
if isinstance(rows_to_eval, list):
167+
eval_res = await execute_pytest_with_exception_handling(
168+
test_func=self.eval_executor,
169+
evaluation_test_kwargs=self.evaluation_test_kwargs,
170+
processed_dataset=rows_to_eval,
171+
)
172+
else:
173+
eval_res = await execute_pytest_with_exception_handling(
174+
test_func=self.eval_executor,
175+
evaluation_test_kwargs=self.evaluation_test_kwargs,
176+
processed_row=rows_to_eval,
177+
)
178+
eval_duration = time.perf_counter() - start_time
179+
# push result to the output buffer
180+
if self.output_buffer:
181+
if isinstance(eval_res, list):
182+
for row in eval_res:
183+
self._post_process_result(row)
184+
await self.output_buffer.add_result(row)
157185
else:
158-
eval_res = await execute_pytest_with_exception_handling(
159-
test_func=self.eval_executor,
160-
evaluation_test_kwargs=self.evaluation_test_kwargs,
161-
processed_row=rows_to_eval,
162-
)
163-
eval_duration = time.perf_counter() - start_time
164-
# push result to the output buffer
165-
if self.output_buffer:
186+
self._post_process_result(eval_res)
187+
await self.output_buffer.add_result(eval_res)
188+
166189
if isinstance(eval_res, list):
167190
for row in eval_res:
168-
self._post_process_result(row)
169-
await self.output_buffer.add_result(row)
191+
row.execution_metadata.eval_duration_seconds = eval_duration
192+
self.results.append(row)
193+
# Update eval progress bar (groupwise: 1 eval for the group)
194+
if self.eval_pbar:
195+
self.eval_pbar.update(1)
170196
else:
171-
self._post_process_result(eval_res)
172-
await self.output_buffer.add_result(eval_res)
173-
174-
if isinstance(eval_res, list):
175-
for row in eval_res:
176-
row.execution_metadata.eval_duration_seconds = eval_duration
177-
self.results.append(row)
178-
# Update eval progress bar (groupwise: 1 eval for the group)
179-
if self.eval_pbar:
180-
self.eval_pbar.update(1)
181-
else:
182-
eval_res.execution_metadata.eval_duration_seconds = eval_duration
183-
self.results.append(eval_res)
184-
# Update eval progress bar (pointwise: 1 eval per row)
185-
if self.eval_pbar:
186-
self.eval_pbar.update(1)
187-
return eval_res
197+
eval_res.execution_metadata.eval_duration_seconds = eval_duration
198+
self.results.append(eval_res)
199+
# Update eval progress bar (pointwise: 1 eval per row)
200+
if self.eval_pbar:
201+
self.eval_pbar.update(1)
202+
return eval_res
203+
finally:
204+
# Decrement active eval counter
205+
async with self.active_evals_lock:
206+
self.active_evals -= 1
207+
if self.eval_pbar:
208+
self.eval_pbar.set_postfix_str(f"active={self.active_evals}")
188209

189210
# 1. Prepare Config & Row for this micro-batch
190211
current_batch_rows = []
@@ -223,23 +244,36 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
223244
batch_results: List[EvaluationRow] = []
224245
if current_batch_rows:
225246
for idx, row in current_batch_rows:
247+
# Track this rollout as active
248+
async with self.active_rollouts_lock:
249+
self.active_rollouts[task.row_index].add(idx)
250+
await self._update_rollout_pbar_postfix()
251+
226252
start_time = time.perf_counter()
227-
async for result_row in rollout_processor_with_retry(
228-
self.rollout_processor, [row], task.config, idx, disable_tqdm=True
229-
):
230-
rollout_duration = time.perf_counter() - start_time
231-
result_row.execution_metadata.rollout_duration_seconds = rollout_duration
232-
batch_results.append(result_row)
233-
234-
# Update rollout progress bar
235-
if self.rollout_pbar:
236-
self.rollout_pbar.update(1)
237-
238-
# in pointwise, we start evaluation immediately
239-
if self.mode == "pointwise":
240-
t = asyncio.create_task(_run_eval(result_row))
241-
self.background_tasks.add(t)
242-
t.add_done_callback(self.background_tasks.discard)
253+
try:
254+
async for result_row in rollout_processor_with_retry(
255+
self.rollout_processor, [row], task.config, idx, disable_tqdm=True
256+
):
257+
rollout_duration = time.perf_counter() - start_time
258+
result_row.execution_metadata.rollout_duration_seconds = rollout_duration
259+
batch_results.append(result_row)
260+
261+
# Update rollout progress bar
262+
if self.rollout_pbar:
263+
self.rollout_pbar.update(1)
264+
265+
# in pointwise, we start evaluation immediately
266+
if self.mode == "pointwise":
267+
t = asyncio.create_task(_run_eval(result_row))
268+
self.background_tasks.add(t)
269+
t.add_done_callback(self.background_tasks.discard)
270+
finally:
271+
# Remove from active tracking
272+
async with self.active_rollouts_lock:
273+
self.active_rollouts[task.row_index].discard(idx)
274+
if not self.active_rollouts[task.row_index]:
275+
del self.active_rollouts[task.row_index]
276+
await self._update_rollout_pbar_postfix()
243277

244278
# 3. Evaluate and Collect History
245279
current_batch_history_updates = []
@@ -283,6 +317,34 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
283317
)
284318
self.queue.put_nowait(new_task)
285319

320+
def _format_active_rollouts(self) -> str:
321+
"""Format active rollouts for display in progress bar."""
322+
if not self.active_rollouts:
323+
return ""
324+
325+
# Show active rows and their run indices
326+
parts = []
327+
for row_idx in sorted(self.active_rollouts.keys())[:5]: # Limit to 5 rows to keep it readable
328+
runs = sorted(self.active_rollouts[row_idx])
329+
if runs:
330+
runs_str = ",".join(str(r) for r in runs[:3]) # Show up to 3 run indices
331+
if len(runs) > 3:
332+
runs_str += f"+{len(runs)-3}"
333+
parts.append(f"r{row_idx}:[{runs_str}]")
334+
335+
if len(self.active_rollouts) > 5:
336+
parts.append(f"+{len(self.active_rollouts)-5} more")
337+
338+
return " | ".join(parts)
339+
340+
async def _update_rollout_pbar_postfix(self):
341+
"""Update the rollout progress bar postfix with active tasks info."""
342+
if self.rollout_pbar:
343+
active_count = sum(len(runs) for runs in self.active_rollouts.values())
344+
self.rollout_pbar.set_postfix_str(
345+
f"active={active_count} {self._format_active_rollouts()}"
346+
)
347+
286348
def _post_process_result(self, res: EvaluationRow):
287349
"""
288350
Process evaluation result: update cost metrics, status, and log.

0 commit comments

Comments
 (0)