Skip to content

Commit fb9c58c

Browse files
committed
tqdm progress bar inside the scheduler
1 parent 25b5ec5 commit fb9c58c

1 file changed

Lines changed: 72 additions & 22 deletions

File tree

eval_protocol/pytest/priority_scheduler.py

Lines changed: 72 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
11
import asyncio
22
import logging
33
import os
4+
import time
45
from collections import defaultdict
56
from dataclasses import dataclass, field
6-
from typing import Any, Callable, List, Dict, Optional, Union, Awaitable
7+
from typing import Any, List, Dict, Optional, Union
8+
9+
from tqdm.asyncio import tqdm as async_tqdm
710

811
from eval_protocol.models import EvaluationRow, Status
912
from eval_protocol.pytest.types import RolloutProcessorConfig, TestFunction
@@ -14,7 +17,6 @@
1417
from eval_protocol.human_id import generate_id
1518
from eval_protocol.log_utils.rollout_context import rollout_logging_context
1619
from eval_protocol.pytest.execution import execute_pytest_with_exception_handling
17-
import time
1820

1921
ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1"
2022

@@ -80,6 +82,10 @@ def __init__(
8082
self.rollout_n = rollout_n
8183
self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n
8284
self.evaluation_test_kwargs = evaluation_test_kwargs
85+
86+
# Progress bars (initialized in run())
87+
self.rollout_pbar: Optional[async_tqdm] = None
88+
self.eval_pbar: Optional[async_tqdm] = None
8389

8490
async def schedule_dataset(
8591
self,
@@ -169,9 +175,15 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
169175
for row in eval_res:
170176
row.execution_metadata.eval_duration_seconds = eval_duration
171177
self.results.append(row)
178+
# Update eval progress bar (groupwise: 1 eval for the group)
179+
if self.eval_pbar:
180+
self.eval_pbar.update(1)
172181
else:
173182
eval_res.execution_metadata.eval_duration_seconds = eval_duration
174183
self.results.append(eval_res)
184+
# Update eval progress bar (pointwise: 1 eval per row)
185+
if self.eval_pbar:
186+
self.eval_pbar.update(1)
175187
return eval_res
176188

177189
# 1. Prepare Config & Row for this micro-batch
@@ -211,10 +223,18 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]):
211223
batch_results: List[EvaluationRow] = []
212224
if current_batch_rows:
213225
for idx, row in current_batch_rows:
226+
start_time = time.perf_counter()
214227
async for result_row in rollout_processor_with_retry(
215228
self.rollout_processor, [row], task.config, idx, disable_tqdm=True
216229
):
230+
rollout_duration = time.perf_counter() - start_time
231+
result_row.execution_metadata.rollout_duration_seconds = rollout_duration
217232
batch_results.append(result_row)
233+
234+
# Update rollout progress bar
235+
if self.rollout_pbar:
236+
self.rollout_pbar.update(1)
237+
218238
# in pointwise, we start evaluation immediately
219239
if self.mode == "pointwise":
220240
t = asyncio.create_task(_run_eval(result_row))
@@ -300,28 +320,58 @@ def _post_process_result(self, res: EvaluationRow):
300320
async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: RolloutProcessorConfig):
301321
self.num_runs = num_runs
302322

303-
# 1. Schedule initial tasks
304-
await self.schedule_dataset(dataset, base_config)
305-
306-
# 2. Start Workers
307-
# If we have separate limits, we need enough workers to saturate both stages
308-
num_workers = self.max_concurrent_rollouts
309-
310-
workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)]
311-
312-
# 3. Wait for completion
313-
await self.queue.join()
323+
# Calculate totals for progress bars
324+
total_rollouts = len(dataset) * num_runs
325+
# In pointwise mode: 1 eval per rollout; in groupwise mode: 1 eval per dataset row
326+
total_evals = total_rollouts if self.mode == "pointwise" else len(dataset)
314327

315-
# Wait for background evaluations to finish
316-
if self.background_tasks:
317-
await asyncio.gather(*self.background_tasks, return_exceptions=True)
328+
# Initialize progress bars
329+
self.rollout_pbar = async_tqdm(
330+
total=total_rollouts,
331+
desc="🚀 Rollouts",
332+
unit="row",
333+
position=0,
334+
leave=True,
335+
colour="cyan",
336+
)
337+
self.eval_pbar = async_tqdm(
338+
total=total_evals,
339+
desc="📊 Evals",
340+
unit="eval",
341+
position=1,
342+
leave=True,
343+
colour="green",
344+
)
318345

319-
# 4. Cleanup
320-
for w in workers:
321-
w.cancel()
322-
323-
if workers:
324-
await asyncio.gather(*workers, return_exceptions=True)
346+
try:
347+
# 1. Schedule initial tasks
348+
await self.schedule_dataset(dataset, base_config)
349+
350+
# 2. Start Workers
351+
# If we have separate limits, we need enough workers to saturate both stages
352+
num_workers = self.max_concurrent_rollouts
353+
354+
workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)]
355+
356+
# 3. Wait for completion
357+
await self.queue.join()
358+
359+
# Wait for background evaluations to finish
360+
if self.background_tasks:
361+
await asyncio.gather(*self.background_tasks, return_exceptions=True)
362+
363+
# 4. Cleanup
364+
for w in workers:
365+
w.cancel()
366+
367+
if workers:
368+
await asyncio.gather(*workers, return_exceptions=True)
369+
finally:
370+
# Close progress bars
371+
if self.rollout_pbar:
372+
self.rollout_pbar.close()
373+
if self.eval_pbar:
374+
self.eval_pbar.close()
325375

326376
# Return collected results
327377
return self.results

0 commit comments

Comments
 (0)