|
36 | 36 | import time |
37 | 37 | from dataclasses import dataclass |
38 | 38 | from pathlib import Path |
| 39 | +from typing import Any |
39 | 40 |
|
40 | 41 | from benchmarks.llm_head_to_head import ( |
41 | 42 | cortex_caller, |
|
46 | 47 | ) |
47 | 48 | from benchmarks.llm_head_to_head.data_loader import BeamItem |
48 | 49 | from benchmarks.llm_head_to_head.generator import ( |
| 50 | + GeneratorError, |
| 51 | + GeneratorResponse, |
49 | 52 | PRICING_USD_PER_M_TOKEN, |
| 53 | + call_generator, |
50 | 54 | estimate_cost_usd, |
51 | 55 | ) |
| 56 | +from benchmarks.llm_head_to_head.judge import ( |
| 57 | + JUDGE_FOR_GENERATOR, |
| 58 | + SINGLE_JUDGE_MODEL, |
| 59 | + JudgePanel, |
| 60 | + judge_item, |
| 61 | +) |
52 | 62 | from benchmarks.llm_head_to_head.manifest import ( |
| 63 | + ItemResultLine, |
| 64 | + Manifest, |
53 | 65 | ManifestModelEntry, |
| 66 | + append_item_result, |
54 | 67 | build_manifest, |
| 68 | + update_cost_tracking, |
55 | 69 | write_manifest, |
56 | 70 | ) |
57 | 71 |
|
@@ -262,18 +276,287 @@ def main(argv: list[str] | None = None) -> int: |
262 | 276 | print("[orchestrator] DRY RUN — no API calls made.") |
263 | 277 | return 0 |
264 | 278 |
|
265 | | - # Live-mode wiring deliberately not implemented in Stage 0 (per |
266 | | - # protocol §12 timeline). The pilot.py script runs Stage 1 (B+C |
267 | | - # only on Haiku), and the eventual full-panel Stage 2 builds on |
268 | | - # this orchestrator. Stage 0 stops here. |
| 279 | + # Live-mode CLI is intentionally minimal — pilot.py is the canonical |
| 280 | + # entry point for Stage 1 (B+C on Haiku) and forwards into ``run_live`` |
| 281 | + # below. The orchestrator's CLI here exists only for ad-hoc smoke runs. |
269 | 282 | print( |
270 | | - "[orchestrator] Live mode not yet wired — Stage 0 commits scaffold " |
271 | | - "only. See tasks/beam-10m-llm-head-to-head-protocol.md §12 for " |
272 | | - "the timeline. Use --dry-run for now.", |
| 283 | + "[orchestrator] Live mode CLI not implemented; use " |
| 284 | + "`python -m benchmarks.llm_head_to_head.pilot --run` to fire a " |
| 285 | + "live pilot. The orchestrator library functions (run_live, " |
| 286 | + "build_context) are imported by pilot.py.", |
273 | 287 | file=sys.stderr, |
274 | 288 | ) |
275 | 289 | return 2 |
276 | 290 |
|
277 | 291 |
|
| 292 | +# ── Live runner ────────────────────────────────────────────────────── |
| 293 | + |
| 294 | + |
@dataclass(frozen=True)
class LiveCellResult:
    """One (item × condition × generator) cell after live execution.

    Immutable record produced by ``_generate_one_cell`` and later copied
    field-for-field into an ``ItemResultLine`` by ``_emit_item_line``.
    """

    # Identity of the cell: which benchmark item, which ability bucket,
    # which context condition, and which generator model answered.
    question_id: str
    ability: str
    condition: str
    generator_model: str
    # Text of the generator's answer (``GeneratorResponse.text``).
    generator_response: str
    # Token counts as reported on the generator response.
    input_tokens: int
    output_tokens: int
    # Number of entries in the response's ``retries`` collection.
    retry_count: int
    # Result of ``estimate_cost_usd`` for this cell's token counts.
    estimated_usd: float
    # Elapsed seconds measured around the generator call only.
    wall_time_s: float
| 309 | + |
| 310 | + |
def _generate_one_cell(
    item: BeamItem,
    condition: str,
    generator_model: str,
    answer_template: str,
    db_for_rag: Any,
) -> LiveCellResult:
    """Build the condition's context, render the prompt, fire one generator call.

    pre:
      - ``condition`` ∈ ALL_CONDITIONS.
      - ``answer_template`` is the contents of ``prompts/answer.md``.
      - For B: ``db_for_rag`` is a BenchmarkDB-like with the BEAM memories
        already loaded under ``domain='beam'``.
      - For C: the production memory store has been seeded with the same
        memories under ``domain='beam'``.
    post:
      - returns one ``LiveCellResult``; raises ``GeneratorError`` if the
        vendor call exhausted retries (so the caller can decide whether
        to skip the cell or abort the run).
      - ``wall_time_s`` covers only the ``call_generator`` invocation;
        context building and prompt rendering are not included.
    """
    ctx = build_context(condition, item, generator_model, db_for_rag)
    prompt = render_answer_prompt(answer_template, ctx.text, item.question)

    # Use a monotonic clock for the duration: ``time.time()`` follows the
    # system wall clock and can jump (NTP, manual adjustment), which would
    # yield negative or skewed elapsed times in the results.
    started = time.perf_counter()
    response: GeneratorResponse = call_generator(
        model_id=generator_model,
        prompt=prompt,
        max_output_tokens=4_000,
        temperature=0.0,
        dry_run=False,
    )
    wall = time.perf_counter() - started

    cost = estimate_cost_usd(
        generator_model, response.input_tokens, response.output_tokens
    )
    return LiveCellResult(
        question_id=item.question_id,
        ability=item.ability,
        condition=condition,
        generator_model=generator_model,
        generator_response=response.text,
        input_tokens=response.input_tokens,
        output_tokens=response.output_tokens,
        retry_count=len(response.retries),
        estimated_usd=cost,
        wall_time_s=wall,
    )
| 360 | + |
| 361 | + |
def _format_support_for_judge(item: BeamItem) -> str:
    """Produce the text used for the judge prompt's SUPPORT field.

    pre: ``item`` carries source_chat_ids that index into ``item.turns``.
    post: returns whatever ``oracle_loader.passages_to_context`` yields for
        the oracle passages built from ``item``.
        NOTE(review): presumably an empty string when ``source_chat_ids``
        is empty (abstention items) — confirm against ``oracle_loader``.
    """
    return oracle_loader.passages_to_context(
        oracle_loader.build_oracle_context(item)
    )
| 371 | + |
| 372 | + |
def run_live(
    items: list[BeamItem],
    conditions: tuple[str, ...],
    generator_model: str,
    judge_mode: str,
    results_dir: Path,
    answer_template: str,
    judge_template: str,
    db_for_rag: Any,
    cost_ceiling_usd: float,
) -> dict[str, Any]:
    """End-to-end live run. Builds contexts, generates answers, judges, writes results.

    pre:
      - ``items`` is non-empty.
      - ``conditions`` ⊆ ALL_CONDITIONS.
      - ``generator_model`` is in ``VENDOR_BY_MODEL`` and has a configured judge.
      - ``results_dir`` already contains a manifest.json (caller wrote it
        before calling this function); we only append items.jsonl + patch
        cost_tracking.
      - ``cost_ceiling_usd`` is a hard limit checked before each item: once
        the running total exceeds it, no further item is started and the
        summary carries ``aborted=True`` plus ``abort_reason``. Because the
        check is per item, the total may overshoot by at most one item's
        generator + judge cost.
    post:
      - returns a summary dict of run-level totals and counters (``items``,
        ``conditions``, ``generator``, ``judge_mode``, token and USD totals,
        ``cells_run``, ``cells_failed``, ``judge_calls``, ``aborted``).
        Per-cell answers and verdicts are NOT in the summary; they are
        emitted through ``_emit_item_line``.
      - one result row is emitted per cell whose generator call succeeded,
        provided the item reached the judge step. Cells that raised
        ``GeneratorError`` are only counted in ``cells_failed``. If the
        judge call fails, the item's cells are emitted with
        ``judge_label="error"``.
      - manifest.json's cost_tracking is incremented by the tokens/USD
        accumulated so far, even when an unexpected exception escapes the
        loop (the exception is then re-raised after the manifest is patched).
    """
    manifest_path = results_dir / "manifest.json"
    summary: dict[str, Any] = {
        "items": len(items),
        "conditions": list(conditions),
        "generator": generator_model,
        "judge_mode": judge_mode,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_usd": 0.0,
        "cells_run": 0,
        "cells_failed": 0,
        "judge_calls": 0,
        "aborted": False,
    }

    # Running totals used both for the ceiling check and for the final
    # manifest patch. They include every generator cell and judge call
    # that completed successfully so far.
    total_usd = 0.0
    total_input = 0
    total_output = 0

    try:
        for item_idx, item in enumerate(items, start=1):
            if total_usd > cost_ceiling_usd:
                summary["aborted"] = True
                summary["abort_reason"] = (
                    f"cost_ceiling exceeded after {item_idx - 1} items "
                    f"(total ${total_usd:.4f} > ceiling ${cost_ceiling_usd:.4f})"
                )
                print(f"[orchestrator] {summary['abort_reason']}", file=sys.stderr)
                break

            print(
                f"[orchestrator] item {item_idx}/{len(items)} "
                f"({item.question_id}, ability={item.ability})",
                file=sys.stderr,
            )

            candidates_by_condition: dict[str, str] = {}
            cell_results: list[LiveCellResult] = []
            for cond in conditions:
                try:
                    cell = _generate_one_cell(
                        item=item,
                        condition=cond,
                        generator_model=generator_model,
                        answer_template=answer_template,
                        db_for_rag=db_for_rag,
                    )
                except GeneratorError as e:
                    print(
                        f"[orchestrator] cell {item.question_id}/{cond} failed: {e}",
                        file=sys.stderr,
                    )
                    summary["cells_failed"] += 1
                    continue

                cell_results.append(cell)
                candidates_by_condition[cond] = cell.generator_response
                total_usd += cell.estimated_usd
                total_input += cell.input_tokens
                total_output += cell.output_tokens
                summary["cells_run"] += 1

            if not candidates_by_condition:
                print(
                    f"[orchestrator] all cells for {item.question_id} failed; "
                    "skipping judge",
                    file=sys.stderr,
                )
                continue

            # Judge: one call, judges all conditions for this item via the
            # cross-vendor pairing in JUDGE_FOR_GENERATOR (or single-judge Opus).
            try:
                panel: JudgePanel = judge_item(
                    question_id=item.question_id,
                    question=item.question,
                    ability=item.ability,
                    gold=item.gold_answer or "",
                    support=_format_support_for_judge(item),
                    candidates_by_condition=candidates_by_condition,
                    judge_template=judge_template,
                    generator_model_id=generator_model,
                    judge_mode=judge_mode,
                    dry_run=False,
                )
            except GeneratorError as e:
                print(
                    f"[orchestrator] judge failed for {item.question_id}: {e}",
                    file=sys.stderr,
                )
                # Cells still produced answers — record with judge_label="error"
                # rather than dropping them silently.
                for cell in cell_results:
                    _emit_item_line(
                        results_dir,
                        cell,
                        judge_label="error",
                    )
                continue

            verdict_by_cond = {v.condition: v.verdict for v in panel.verdicts}
            judge_cost = estimate_cost_usd(
                panel.judge_model,
                panel.judge_response.input_tokens,
                panel.judge_response.output_tokens,
            )
            total_usd += judge_cost
            total_input += panel.judge_response.input_tokens
            total_output += panel.judge_response.output_tokens
            summary["judge_calls"] += 1

            for cell in cell_results:
                if cell.condition not in verdict_by_cond:
                    # NOTE(review): a verdict missing from the panel is
                    # recorded as "incorrect" (existing behaviour). Surface
                    # it so the default cannot bias results unnoticed.
                    print(
                        f"[orchestrator] judge returned no verdict for "
                        f"{item.question_id}/{cell.condition}; "
                        "recording 'incorrect'",
                        file=sys.stderr,
                    )
                label = verdict_by_cond.get(cell.condition, "incorrect")
                _emit_item_line(results_dir, cell, judge_label=label)
    finally:
        # Runs on normal completion, on ceiling abort, and when an
        # unexpected exception (I/O error, KeyboardInterrupt, ...) escapes:
        # money already spent must always reach the manifest.
        summary["total_input_tokens"] = total_input
        summary["total_output_tokens"] = total_output
        summary["total_usd"] = round(total_usd, 6)

        # Patch manifest cost_tracking.
        if manifest_path.exists():
            update_cost_tracking(
                manifest_path,
                add_input_tokens=total_input,
                add_output_tokens=total_output,
                add_usd=total_usd,
            )
        else:
            print(
                f"[orchestrator] manifest not found at {manifest_path}; "
                "cost_tracking not updated",
                file=sys.stderr,
            )
    return summary
| 531 | + |
| 532 | + |
def _emit_item_line(
    results_dir: Path, cell: LiveCellResult, judge_label: str
) -> None:
    """Record one cell (judged, or generated but judge-failed) as a result row.

    pre: ``judge_label`` is one of the protocol verdicts OR the literal
        ``'error'`` (judge call failed; the cell answer is preserved for audit).
    post: one ``ItemResultLine`` mirroring ``cell`` plus ``judge_label`` has
        been handed to ``append_item_result`` for ``results_dir``.
        NOTE(review): nothing is caught here, so any exception raised while
        building or appending the row propagates to the caller.
    """
    row = ItemResultLine(
        question_id=cell.question_id,
        ability=cell.ability,
        condition=cell.condition,
        generator_model=cell.generator_model,
        generator_response=cell.generator_response,
        judge_label=judge_label,
        input_tokens=cell.input_tokens,
        output_tokens=cell.output_tokens,
        retry_count=cell.retry_count,
        estimated_usd=cell.estimated_usd,
        wall_time_s=cell.wall_time_s,
    )
    append_item_result(results_dir, row)
| 559 | + |
| 560 | + |
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the exit status.
    raise SystemExit(main())
0 commit comments