|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | 3 | import os |
| 4 | +import time |
4 | 5 | from collections import defaultdict |
5 | 6 | from dataclasses import dataclass, field |
6 | | -from typing import Any, Callable, List, Dict, Optional, Union, Awaitable |
| 7 | +from typing import Any, List, Dict, Optional, Union |
| 8 | + |
| 9 | +from tqdm.asyncio import tqdm as async_tqdm |
7 | 10 |
|
8 | 11 | from eval_protocol.models import EvaluationRow, Status |
9 | 12 | from eval_protocol.pytest.types import RolloutProcessorConfig, TestFunction |
|
14 | 17 | from eval_protocol.human_id import generate_id |
15 | 18 | from eval_protocol.log_utils.rollout_context import rollout_logging_context |
16 | 19 | from eval_protocol.pytest.execution import execute_pytest_with_exception_handling |
17 | | -import time |
18 | 20 |
|
19 | 21 | ENABLE_SPECULATION = os.getenv("ENABLE_SPECULATION", "0").strip() == "1" |
20 | 22 |
|
@@ -80,6 +82,10 @@ def __init__( |
80 | 82 | self.rollout_n = rollout_n |
81 | 83 | self.in_group_minibatch_size = in_group_minibatch_size if in_group_minibatch_size > 0 else rollout_n |
82 | 84 | self.evaluation_test_kwargs = evaluation_test_kwargs |
| 85 | + |
| 86 | + # Progress bars (initialized in run()) |
| 87 | + self.rollout_pbar: Optional[async_tqdm] = None |
| 88 | + self.eval_pbar: Optional[async_tqdm] = None |
83 | 89 |
|
84 | 90 | async def schedule_dataset( |
85 | 91 | self, |
@@ -169,9 +175,15 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): |
169 | 175 | for row in eval_res: |
170 | 176 | row.execution_metadata.eval_duration_seconds = eval_duration |
171 | 177 | self.results.append(row) |
| 178 | + # Update eval progress bar (groupwise: 1 eval for the group) |
| 179 | + if self.eval_pbar: |
| 180 | + self.eval_pbar.update(1) |
172 | 181 | else: |
173 | 182 | eval_res.execution_metadata.eval_duration_seconds = eval_duration |
174 | 183 | self.results.append(eval_res) |
| 184 | + # Update eval progress bar (pointwise: 1 eval per row) |
| 185 | + if self.eval_pbar: |
| 186 | + self.eval_pbar.update(1) |
175 | 187 | return eval_res |
176 | 188 |
|
177 | 189 | # 1. Prepare Config & Row for this micro-batch |
@@ -211,10 +223,18 @@ async def _run_eval(rows_to_eval: Union[EvaluationRow, List[EvaluationRow]]): |
211 | 223 | batch_results: List[EvaluationRow] = [] |
212 | 224 | if current_batch_rows: |
213 | 225 | for idx, row in current_batch_rows: |
| 226 | + start_time = time.perf_counter() |
214 | 227 | async for result_row in rollout_processor_with_retry( |
215 | 228 | self.rollout_processor, [row], task.config, idx, disable_tqdm=True |
216 | 229 | ): |
| 230 | + rollout_duration = time.perf_counter() - start_time |
| 231 | + result_row.execution_metadata.rollout_duration_seconds = rollout_duration |
217 | 232 | batch_results.append(result_row) |
| 233 | + |
| 234 | + # Update rollout progress bar |
| 235 | + if self.rollout_pbar: |
| 236 | + self.rollout_pbar.update(1) |
| 237 | + |
218 | 238 | # in pointwise, we start evaluation immediately |
219 | 239 | if self.mode == "pointwise": |
220 | 240 | t = asyncio.create_task(_run_eval(result_row)) |
@@ -300,28 +320,58 @@ def _post_process_result(self, res: EvaluationRow): |
300 | 320 | async def run(self, dataset: List[EvaluationRow], num_runs: int, base_config: RolloutProcessorConfig): |
301 | 321 | self.num_runs = num_runs |
302 | 322 |
|
303 | | - # 1. Schedule initial tasks |
304 | | - await self.schedule_dataset(dataset, base_config) |
305 | | - |
306 | | - # 2. Start Workers |
307 | | - # If we have separate limits, we need enough workers to saturate both stages |
308 | | - num_workers = self.max_concurrent_rollouts |
309 | | - |
310 | | - workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] |
311 | | - |
312 | | - # 3. Wait for completion |
313 | | - await self.queue.join() |
| 323 | + # Calculate totals for progress bars |
| 324 | + total_rollouts = len(dataset) * num_runs |
| 325 | + # In pointwise mode: 1 eval per rollout; in groupwise mode: 1 eval per dataset row |
| 326 | + total_evals = total_rollouts if self.mode == "pointwise" else len(dataset) |
314 | 327 |
|
315 | | - # Wait for background evaluations to finish |
316 | | - if self.background_tasks: |
317 | | - await asyncio.gather(*self.background_tasks, return_exceptions=True) |
| 328 | + # Initialize progress bars |
| 329 | + self.rollout_pbar = async_tqdm( |
| 330 | + total=total_rollouts, |
| 331 | + desc="🚀 Rollouts", |
| 332 | + unit="row", |
| 333 | + position=0, |
| 334 | + leave=True, |
| 335 | + colour="cyan", |
| 336 | + ) |
| 337 | + self.eval_pbar = async_tqdm( |
| 338 | + total=total_evals, |
| 339 | + desc="📊 Evals", |
| 340 | + unit="eval", |
| 341 | + position=1, |
| 342 | + leave=True, |
| 343 | + colour="green", |
| 344 | + ) |
318 | 345 |
|
319 | | - # 4. Cleanup |
320 | | - for w in workers: |
321 | | - w.cancel() |
322 | | - |
323 | | - if workers: |
324 | | - await asyncio.gather(*workers, return_exceptions=True) |
| 346 | + try: |
| 347 | + # 1. Schedule initial tasks |
| 348 | + await self.schedule_dataset(dataset, base_config) |
| 349 | + |
| 350 | + # 2. Start Workers |
| 351 | + # If we have separate limits, we need enough workers to saturate both stages |
| 352 | + num_workers = self.max_concurrent_rollouts |
| 353 | + |
| 354 | + workers = [asyncio.create_task(self.worker()) for _ in range(num_workers)] |
| 355 | + |
| 356 | + # 3. Wait for completion |
| 357 | + await self.queue.join() |
| 358 | + |
| 359 | + # Wait for background evaluations to finish |
| 360 | + if self.background_tasks: |
| 361 | + await asyncio.gather(*self.background_tasks, return_exceptions=True) |
| 362 | + |
| 363 | + # 4. Cleanup |
| 364 | + for w in workers: |
| 365 | + w.cancel() |
| 366 | + |
| 367 | + if workers: |
| 368 | + await asyncio.gather(*workers, return_exceptions=True) |
| 369 | + finally: |
| 370 | + # Close progress bars |
| 371 | + if self.rollout_pbar: |
| 372 | + self.rollout_pbar.close() |
| 373 | + if self.eval_pbar: |
| 374 | + self.eval_pbar.close() |
325 | 375 |
|
326 | 376 | # Return collected results |
327 | 377 | return self.results |
|
0 commit comments