diff --git a/docs/docs/data-engineering/agent-modes.md b/docs/docs/data-engineering/agent-modes.md index 5ee3cc9a65..c2ebcd464e 100644 --- a/docs/docs/data-engineering/agent-modes.md +++ b/docs/docs/data-engineering/agent-modes.md @@ -10,12 +10,14 @@ altimate runs in one of four specialized modes. Each mode has different permissi altimate --agent builder ``` -Builder mode follows a strict pre-execution protocol for every SQL operation: +Builder mode follows a pre-execution protocol for every SQL operation: -1. `sql_analyze` — Check for anti-patterns -2. `sql_validate` — Verify syntax and schema references +1. `sql_analyze` — Check for anti-patterns (skipped gracefully if unavailable) +2. `sql_validate` — Verify syntax and schema references (skipped gracefully if unavailable) 3. `sql_execute` — Run the query +Builder also validates output data after dbt operations — not just compilation success — and avoids non-deterministic temporal functions on historical datasets. + ### Example: Create a staging model ``` diff --git a/experiments/spider2_dbt/.gitignore b/experiments/spider2_dbt/.gitignore new file mode 100644 index 0000000000..ca964aa772 --- /dev/null +++ b/experiments/spider2_dbt/.gitignore @@ -0,0 +1,13 @@ +# Spider2 cloned repo (large, re-cloneable) +spider2_repo/ + +# Per-task workspace copies +workspace/ + +# Results and reports (generated artifacts) +results/ +reports/ + +# Python +__pycache__/ +*.pyc diff --git a/experiments/spider2_dbt/README.md b/experiments/spider2_dbt/README.md new file mode 100644 index 0000000000..9678d02409 --- /dev/null +++ b/experiments/spider2_dbt/README.md @@ -0,0 +1,103 @@ +# Spider 2.0-DBT Benchmark Evaluation + +Evaluate **altimate-code** against the [Spider 2.0-DBT](https://spider2-dbt.github.io/) benchmark — 68 real-world dbt + DuckDB data engineering tasks. + +## Quick Start + +```bash +# 1. Install dependencies +pip install -r requirements.txt + +# 2. Setup (clone Spider2 repo, download databases) +python setup_spider2.py + +# 3. Run benchmark (all 68 tasks) +python run_benchmark.py + +# 4. Evaluate against gold standard +python evaluate_results.py + +# 5. 
Generate interactive HTML report +python report.py +``` + +## Smoke Test (5 tasks) + +```bash +python run_benchmark.py --tasks 5 +python evaluate_results.py +python report.py +``` + +## CLI Options + +### `run_benchmark.py` + +| Flag | Default | Description | +|------|---------|-------------| +| `--tasks N` | all | First N tasks | +| `--tasks id1 id2` | all | Specific task IDs | +| `--timeout` | 600 | Seconds per task | +| `--model` | `anthropic/claude-opus-4-6` | Model to use | +| `--agent` | default | Agent to use | +| `--no-resume` | off | Force re-run all tasks | +| `--dry-run` | off | Print tasks without running | + +### `evaluate_results.py` + +| Flag | Default | Description | +|------|---------|-------------| +| `--results` | latest | Path to benchmark results JSON | + +### `report.py` + +| Flag | Default | Description | +|------|---------|-------------| +| `--evaluation` | latest | Path to evaluation JSON | +| `--output` | auto | Output HTML file path | + +## Directory Structure + +``` +experiments/spider2_dbt/ +├── config.py # Paths, leaderboard data, defaults +├── setup_spider2.py # One-time: clone Spider2, download data +├── prompt_template.py # Prompt engineering for each task +├── run_benchmark.py # Runner: invoke altimate-code per task +├── evaluate_results.py # Bridge to Spider2's official eval_utils +├── report.py # Generate interactive single-file HTML report +├── requirements.txt # Python deps +├── results/ # Timestamped JSON results +│ └── incremental/ # Per-task results for resumability +├── reports/ # Generated HTML reports +├── workspace/ # Per-task dbt project copies (gitignored) +└── spider2_repo/ # Cloned Spider2 repository (gitignored) +``` + +## Resumability + +The benchmark runner saves per-task results to `results/incremental/`. If interrupted, re-running `python run_benchmark.py` will skip completed tasks. Use `--no-resume` to force a full re-run. 
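+
+For example, an interrupted run can be resumed without losing progress (illustrative session; the flags and paths are the ones documented above):
+
+```bash
+# First attempt, interrupted partway through (Ctrl-C)
+python run_benchmark.py
+
+# Re-run: tasks already saved under results/incremental/ are skipped
+python run_benchmark.py
+
+# Discard cached per-task results and run everything from scratch
+python run_benchmark.py --no-resume
+```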
+ +## Report Features + +The HTML report is a single self-contained file (no external dependencies): + +- **Summary cards**: Pass rate, total time, model, rank +- **Leaderboard chart**: SVG bar chart with all Spider2 entries + altimate-code highlighted +- **Category breakdown**: Tasks grouped by domain with pass/fail counts +- **Per-task table**: Sortable, filterable, with expandable agent logs +- **Timing histogram**: Distribution of execution times + +## Leaderboard Context + +Current Spider 2.0-DBT leaderboard (as of 2025): + +| Agent | Pass Rate | +|-------|-----------| +| Databao Agent | 44.11% | +| MLE-Bench Agent | 38.24% | +| Claude 3.5 Sonnet (CoT) | 36.76% | +| GPT-4o (CoT) | 33.82% | +| CodeS Agent | 32.35% | +| OpenHands Agent | 30.88% | +| SWE-Agent | 27.94% | diff --git a/experiments/spider2_dbt/config.py b/experiments/spider2_dbt/config.py new file mode 100644 index 0000000000..8b6895fe77 --- /dev/null +++ b/experiments/spider2_dbt/config.py @@ -0,0 +1,72 @@ +"""Configuration constants for Spider 2.0-DBT benchmark evaluation.""" + +from __future__ import annotations + +import os +from pathlib import Path + +# ── Paths ────────────────────────────────────────────────────────────────────── + +BASE_DIR = Path(__file__).resolve().parent +SPIDER2_REPO_DIR = BASE_DIR / "spider2_repo" +SPIDER2_DBT_DIR = SPIDER2_REPO_DIR / "spider2-dbt" +TASK_JSONL = SPIDER2_DBT_DIR / "examples" / "spider2-dbt.jsonl" +EXAMPLES_DIR = SPIDER2_DBT_DIR / "examples" +GOLD_EVAL_JSONL = SPIDER2_DBT_DIR / "evaluation_suite" / "gold" / "spider2_eval.jsonl" +EVAL_UTILS_DIR = SPIDER2_DBT_DIR / "evaluation_suite" +WORKSPACE_DIR = BASE_DIR / "workspace" +RESULTS_DIR = BASE_DIR / "results" +INCREMENTAL_DIR = RESULTS_DIR / "incremental" +REPORTS_DIR = BASE_DIR / "reports" + +# ── Spider2 Repository ───────────────────────────────────────────────────────── + +SPIDER2_REPO_URL = "https://github.com/xlang-ai/Spider2.git" +# Pin to a known-good commit for reproducibility +SPIDER2_COMMIT = "main" + +# Google Drive file IDs for DuckDB database zips (from Spider2 README) +# Format: (gdrive_id, expected_filename) +DUCKDB_ZIP_DOWNLOADS = [ + ("1N3f7BSWC4foj-V-1C9n8M2XmgV7FOcqL", "DBT_start_db.zip"), + ("1s0USV_iQLo4oe05QqAMnhGGp5jeejCzp", "dbt_gold.zip"), +] + +# ── Execution ────────────────────────────────────────────────────────────────── + +ALTIMATE_CODE_BIN = os.environ.get("ALTIMATE_CODE_BIN", "altimate-code") +DEFAULT_TIMEOUT = 600 # seconds per task +DEFAULT_PARALLEL = 4 # concurrent tasks +DEFAULT_MODEL = "anthropic/claude-opus-4-6" +DEFAULT_AGENT = "coder" + +# ── Leaderboard Data (Spider 2.0-DBT, as of 2025) ───────────────────────────── +# Source: https://spider2-dbt.github.io/ +# Format: (agent_name, pass_rate) + +LEADERBOARD: list[tuple[str, float]] = [ + ("Databao Agent", 44.11), + ("MLE-Bench Agent", 38.24), + ("Claude 3.5 Sonnet (CoT)", 36.76), + ("GPT-4o (CoT)", 33.82), + ("CodeS Agent", 32.35), + ("OpenHands Agent", 30.88), + ("SWE-Agent", 27.94), + ("Gemini 1.5 Pro (CoT)", 26.47), + ("Llama 3.1 405B (CoT)", 22.06), + ("GPT-4o mini (CoT)", 19.12), + ("Claude 3 Haiku (CoT)", 16.18), +] + +# ── Task Categories (domain grouping for report) ────────────────────────────── +# Extract domain from instance_id by stripping trailing digits + +import re + + +def get_task_domain(instance_id: str) -> str: + """Extract domain from instance_id by stripping trailing digits. + + e.g. 
'shopify002' -> 'shopify', 'f1003' -> 'f1', 'tpch001' -> 'tpch' + """ + return re.sub(r"\d+$", "", instance_id) diff --git a/experiments/spider2_dbt/evaluate_results.py b/experiments/spider2_dbt/evaluate_results.py new file mode 100644 index 0000000000..d8d5710ec7 --- /dev/null +++ b/experiments/spider2_dbt/evaluate_results.py @@ -0,0 +1,305 @@ +"""Evaluate benchmark results using Spider2's official eval_utils. + +Compares workspace DuckDB outputs against gold standard databases using +the official `duckdb_match` function from Spider2's evaluation suite. + +Usage: + python evaluate_results.py # Use latest results + python evaluate_results.py --results results/spider2_benchmark_*.json +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from config import ( + EVAL_UTILS_DIR, + GOLD_EVAL_JSONL, + RESULTS_DIR, + SPIDER2_DBT_DIR, + WORKSPACE_DIR, + get_task_domain, +) + + +def add_eval_utils_to_path() -> None: + """Add Spider2's evaluation_suite to sys.path for importing eval_utils.""" + for p in [str(EVAL_UTILS_DIR), str(SPIDER2_DBT_DIR)]: + if p not in sys.path: + sys.path.insert(0, p) + + +def load_gold_standard() -> dict[str, dict[str, Any]]: + """Load gold evaluation data keyed by instance_id.""" + if not GOLD_EVAL_JSONL.exists(): + print(f"ERROR: Gold evaluation file not found: {GOLD_EVAL_JSONL}") + print("Run `python setup_spider2.py` first.") + sys.exit(1) + + gold = {} + for line in GOLD_EVAL_JSONL.read_text().strip().splitlines(): + line = line.strip() + if line: + entry = json.loads(line) + gold[entry["instance_id"]] = entry + return gold + + +def find_latest_results() -> Path: + """Find the latest benchmark results file.""" + latest = RESULTS_DIR / "latest.json" + if latest.exists() or latest.is_symlink(): + return latest.resolve() + + results_files = sorted(RESULTS_DIR.glob("spider2_benchmark_*.json"), reverse=True) + if not results_files: + print("ERROR: No benchmark results found. Run `python run_benchmark.py` first.") + sys.exit(1) + return results_files[0] + + +def find_workspace_duckdb(instance_id: str) -> str | None: + """Find the DuckDB file in the workspace for a given task.""" + workspace = WORKSPACE_DIR / instance_id + + if not workspace.exists(): + return None + + # Search for .duckdb files (exclude target/ build artifacts) + db_files = list(workspace.glob("*.duckdb")) + if db_files: + return str(db_files[0]) + + # Check subdirectories (some projects have db in subdirs) + db_files = list(workspace.rglob("*.duckdb")) + # Prefer non-target files + non_target = [f for f in db_files if "target" not in str(f)] + if non_target: + return str(non_target[0]) + if db_files: + return str(db_files[0]) + + return None + + +def find_gold_duckdb(instance_id: str, gold_filename: str) -> str | None: + """Find the gold DuckDB file for a given task.""" + gold_dir = SPIDER2_DBT_DIR / "evaluation_suite" / "gold" / instance_id + if not gold_dir.exists(): + return None + + # Try exact filename first + gold_path = gold_dir / gold_filename + if gold_path.exists(): + return str(gold_path) + + # Fallback: use any .duckdb file in the gold directory + db_files = list(gold_dir.glob("*.duckdb")) + if db_files: + return str(db_files[0]) + + return None + + +def evaluate_task( + instance_id: str, + gold_entry: dict[str, Any], +) -> dict[str, Any]: + """Evaluate a single task using Spider2's official duckdb_match. 
+ + The gold_entry has format: + { + "instance_id": "...", + "evaluation": { + "func": "duckdb_match", + "parameters": { + "gold": "filename.duckdb", + "condition_tabs": ["table1", "table2"], + "condition_cols": [[col_indices], [col_indices]], + "ignore_orders": [true, true] + } + } + } + """ + result = { + "instance_id": instance_id, + "passed": False, + "error": None, + "method": "unknown", + } + + eval_spec = gold_entry.get("evaluation", {}) + eval_func = eval_spec.get("func", "") + params = eval_spec.get("parameters", {}) + + if eval_func != "duckdb_match": + result["error"] = f"Unsupported eval function: {eval_func}" + return result + + # Find workspace DuckDB (the result produced by the agent) + workspace_db = find_workspace_duckdb(instance_id) + if not workspace_db: + result["error"] = "No DuckDB file found in workspace" + return result + + # Find gold DuckDB + gold_filename = params.get("gold", "") + gold_db = find_gold_duckdb(instance_id, gold_filename) + if not gold_db: + result["error"] = f"Gold DuckDB not found: {instance_id}/{gold_filename}" + return result + + # Call the official eval function + try: + from eval_utils import duckdb_match + + score = duckdb_match( + result=workspace_db, + gold=gold_db, + condition_tabs=params.get("condition_tabs"), + condition_cols=params.get("condition_cols"), + ignore_orders=params.get("ignore_orders"), + ) + result["passed"] = score == 1 + result["method"] = "spider2_duckdb_match" + except ImportError: + result["error"] = "Could not import eval_utils.duckdb_match" + except Exception as e: + result["error"] = f"Evaluation error: {str(e)[:300]}" + + return result + + +def main() -> None: + parser = argparse.ArgumentParser(description="Evaluate Spider 2.0-DBT benchmark results") + parser.add_argument("--results", type=str, default=None, help="Path to benchmark results JSON") + args = parser.parse_args() + + print("=" * 60) + print("Spider 2.0-DBT Benchmark Evaluation") + print("=" * 60) + + # Add eval_utils to path + add_eval_utils_to_path() + + # Load results + results_path = Path(args.results) if args.results else find_latest_results() + print(f" Results file: {results_path}") + benchmark = json.loads(results_path.read_text()) + + # Load gold standard + gold = load_gold_standard() + print(f" Gold entries: {len(gold)}") + + task_results = benchmark.get("task_results", []) + print(f" Tasks to evaluate: {len(task_results)}") + print() + + # Evaluate each task + evaluations = [] + passed = 0 + failed = 0 + errors = 0 + + for i, task_result in enumerate(task_results, 1): + instance_id = task_result["instance_id"] + gold_entry = gold.get(instance_id) + + if gold_entry is None: + print(f" [{i}/{len(task_results)}] {instance_id} — NO GOLD (skipped)") + evaluations.append({ + "instance_id": instance_id, + "passed": False, + "error": "No gold standard entry", + "method": "skipped", + }) + errors += 1 + continue + + eval_result = evaluate_task(instance_id, gold_entry) + evaluations.append(eval_result) + + if eval_result["passed"]: + status = "PASS" + passed += 1 + elif eval_result["error"]: + status = f"ERROR: {eval_result['error'][:50]}" + errors += 1 + else: + status = "FAIL" + failed += 1 + + print(f" [{i}/{len(task_results)}] {instance_id} — {status}") + + total = len(task_results) + pass_rate = (passed / total * 100) if total > 0 else 0.0 + + # Domain breakdown + domain_stats: dict[str, dict[str, int]] = {} + for eval_r, task_r in zip(evaluations, task_results): + domain = get_task_domain(task_r["instance_id"]) + if domain not in domain_stats: + 
domain_stats[domain] = {"total": 0, "passed": 0, "failed": 0, "errors": 0} + domain_stats[domain]["total"] += 1 + if eval_r["passed"]: + domain_stats[domain]["passed"] += 1 + elif eval_r.get("error"): + domain_stats[domain]["errors"] += 1 + else: + domain_stats[domain]["failed"] += 1 + + # Save evaluation results + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + evaluation = { + "timestamp": timestamp, + "source_results": str(results_path), + "model": benchmark.get("model", "unknown"), + "total": total, + "passed": passed, + "failed": failed, + "errors": errors, + "pass_rate": round(pass_rate, 2), + "domain_stats": domain_stats, + "evaluations": evaluations, + } + + eval_path = RESULTS_DIR / f"evaluation_{timestamp}.json" + eval_path.write_text(json.dumps(evaluation, indent=2)) + + # Latest symlink + latest = RESULTS_DIR / "evaluation_latest.json" + if latest.is_symlink() or latest.exists(): + latest.unlink() + latest.symlink_to(eval_path.name) + + # Print summary + print() + print("=" * 60) + print("Evaluation Summary") + print("=" * 60) + print(f" Total: {total}") + print(f" Passed: {passed}") + print(f" Failed: {failed}") + print(f" Errors: {errors}") + print(f" Pass Rate: {pass_rate:.2f}%") + print() + + print("Domain Breakdown:") + for domain, stats in sorted(domain_stats.items()): + dr = (stats["passed"] / stats["total"] * 100) if stats["total"] > 0 else 0 + print(f" {domain:20s} {stats['passed']}/{stats['total']} ({dr:.1f}%)") + + print() + print(f" Evaluation saved: {eval_path}") + print() + print("Next: python report.py") + + +if __name__ == "__main__": + main() diff --git a/experiments/spider2_dbt/prompt_template.py b/experiments/spider2_dbt/prompt_template.py new file mode 100644 index 0000000000..1cb38370d4 --- /dev/null +++ b/experiments/spider2_dbt/prompt_template.py @@ -0,0 +1,83 @@ +"""Prompt engineering for Spider 2.0-DBT benchmark tasks. + +Builds a self-contained prompt per task that instructs the agent to: +1. Explore the dbt project structure +2. Understand the task requirements +3. Write/fix SQL models +4. Run `dbt run` to validate +5. Retry on failure (up to 3 times) +""" + +from __future__ import annotations + + +def build_task_prompt( + instance_id: str, + instruction: str, + project_dir: str, +) -> str: + """Build the full prompt for a Spider2-DBT task. + + Args: + instance_id: Unique task identifier (e.g., "ga4_001"). + instruction: The natural language task instruction from the benchmark. + project_dir: Absolute path to the dbt project working directory. + + Returns: + A complete prompt string for the agent. + """ + return f"""You are working on a dbt + DuckDB data engineering task. + +## Task ID: {instance_id} + +## Instruction +{instruction} + +## Working Directory +Your dbt project is at: {project_dir} + +## Steps + +1. **Explore the project structure first:** + - Read `dbt_project.yml` to understand the project configuration + - Read `profiles.yml` to understand the DuckDB connection + - List files in `models/` to see existing SQL models + - Check `seeds/` or `data/` for any CSV seed files + - Look at any existing `.sql` files to understand the schema and naming conventions + +2. **Understand the data:** + - Check what DuckDB databases are available (look for `.duckdb` or `.db` files) + - If needed, query the database to understand table schemas: + ```bash + cd {project_dir} && duckdb *.duckdb -c ".tables" + ``` + - Read any README or documentation files in the project + +3. 
**Implement the solution:** + - Create or modify SQL model files in the `models/` directory as needed + - Follow dbt best practices (use `ref()` for model references, `source()` for sources) + - Ensure your SQL is valid DuckDB SQL syntax + +4. **Validate by running dbt:** + ```bash + cd {project_dir} && dbt run --profiles-dir . --project-dir . + ``` + +5. **If dbt run fails:** + - Read the error message carefully + - Fix the SQL or configuration issue + - Re-run `dbt run --profiles-dir . --project-dir .` + - Retry up to 3 times total + +6. **Final check:** + - Make sure all models compile and run successfully + - Verify the output tables exist in DuckDB + +## Important Rules +- Stay within the project directory: {project_dir} +- Do NOT install new packages or modify system configuration +- Do NOT modify `profiles.yml` unless the task specifically requires it +- Use `dbt run --profiles-dir . --project-dir .` (not just `dbt run`) +- If a model already exists and the task asks to modify it, edit in place +- Write clean, readable SQL with appropriate comments +""" diff --git a/experiments/spider2_dbt/report.py b/experiments/spider2_dbt/report.py new file mode 100644 index 0000000000..22acd2e241 --- /dev/null +++ b/experiments/spider2_dbt/report.py @@ -0,0 +1,523 @@ +"""Generate an interactive single-file HTML report for Spider 2.0-DBT benchmark. + +Usage: + python report.py # Use latest evaluation + python report.py --evaluation results/evaluation_*.json + python report.py --output reports/custom_report.html +""" + +from __future__ import annotations + +import argparse +import html +import json +import sys +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from config import LEADERBOARD, REPORTS_DIR, RESULTS_DIR + + +def find_latest_evaluation() -> Path: + """Find the latest evaluation results file.""" + latest = RESULTS_DIR / "evaluation_latest.json" + if latest.exists() or latest.is_symlink(): + return latest.resolve() + + files = sorted(RESULTS_DIR.glob("evaluation_*.json"), reverse=True) + if not files: + print("ERROR: No evaluation results found. 
Run `python evaluate_results.py` first.") + sys.exit(1) + return files[0] + + +def load_benchmark_results(evaluation: dict[str, Any]) -> dict[str, Any] | None: + """Load the source benchmark results referenced by the evaluation.""" + src = evaluation.get("source_results", "") + if src: + p = Path(src) + if p.exists(): + return json.loads(p.read_text()) + return None + + +def esc(text: str) -> str: + """HTML-escape a string.""" + return html.escape(str(text)) + + +def build_leaderboard_svg(pass_rate: float, model: str) -> str: + """Build a horizontal bar chart SVG comparing against the leaderboard.""" + entries = list(LEADERBOARD) + [(f"Altimate Code ({model})", pass_rate)] + entries.sort(key=lambda x: x[1], reverse=True) + + bar_height = 28 + gap = 4 + label_width = 220 + chart_width = 400 + total_width = label_width + chart_width + 80 + total_height = len(entries) * (bar_height + gap) + 20 + + max_val = max(e[1] for e in entries) + scale = chart_width / max(max_val, 1) + + bars = [] + for i, (name, rate) in enumerate(entries): + y = i * (bar_height + gap) + 10 + w = rate * scale + is_ours = "Altimate Code" in name + + fill = "#6366f1" if is_ours else "#e2e8f0" + text_fill = "#1e1b4b" if is_ours else "#475569" + font_weight = "bold" if is_ours else "normal" + border = ' stroke="#4f46e5" stroke-width="2"' if is_ours else "" + + bars.append(f""" + {esc(name)} + + {rate:.2f}%""") + + return f""" + {"".join(bars)} +""" + + +def build_timing_svg(task_results: list[dict[str, Any]]) -> str: + """Build a histogram SVG of task execution times.""" + times = [t.get("elapsed_s", 0) for t in task_results if t.get("elapsed_s", 0) > 0] + if not times: + return "

No timing data available.

" + + # Bucket into bins + max_time = max(times) + num_bins = min(20, len(times)) + bin_width = max_time / num_bins if num_bins > 0 else 1 + bins = [0] * num_bins + + for t in times: + idx = min(int(t / bin_width), num_bins - 1) + bins[idx] += 1 + + max_count = max(bins) if bins else 1 + chart_w = 600 + chart_h = 200 + bar_w = chart_w / num_bins + scale = (chart_h - 30) / max(max_count, 1) + + bars = [] + for i, count in enumerate(bins): + x = i * bar_w + h = count * scale + y = chart_h - 30 - h + label = f"{bin_width * i:.0f}-{bin_width * (i + 1):.0f}s" + bars.append( + f'' + f"{label}: {count} tasks" + ) + + # X-axis labels (every 4th bin) + labels = [] + for i in range(0, num_bins, max(1, num_bins // 5)): + x = i * bar_w + bar_w / 2 + labels.append( + f'{bin_width * i:.0f}s' + ) + + return f""" + {"".join(bars)} + {"".join(labels)} +""" + + +def build_html(evaluation: dict[str, Any], benchmark: dict[str, Any] | None) -> str: + """Build the complete HTML report.""" + model = evaluation.get("model", "unknown") + total = evaluation.get("total", 0) + passed = evaluation.get("passed", 0) + failed = evaluation.get("failed", 0) + errors = evaluation.get("errors", 0) + pass_rate = evaluation.get("pass_rate", 0.0) + timestamp = evaluation.get("timestamp", "") + domain_stats = evaluation.get("domain_stats", {}) + evaluations = evaluation.get("evaluations", []) + + task_results = benchmark.get("task_results", []) if benchmark else [] + + # Map instance_id -> task result for merging + task_map = {t["instance_id"]: t for t in task_results} + + # Compute projected rank + all_entries = list(LEADERBOARD) + [("Altimate Code", pass_rate)] + all_entries.sort(key=lambda x: x[1], reverse=True) + rank = next(i + 1 for i, (n, _) in enumerate(all_entries) if n == "Altimate Code") + + total_time = benchmark.get("total_elapsed_s", 0) if benchmark else 0 + avg_time = benchmark.get("avg_elapsed_s", 0) if benchmark else 0 + + # Leaderboard chart + leaderboard_svg = build_leaderboard_svg(pass_rate, model) + + # Timing histogram + timing_svg = build_timing_svg(task_results) if task_results else "

No timing data.

" + + # Domain breakdown rows + domain_rows = "" + for domain, stats in sorted(domain_stats.items()): + dr = (stats["passed"] / stats["total"] * 100) if stats["total"] > 0 else 0 + bar_w = dr * 2 # max 200px at 100% + domain_rows += f""" + + {esc(domain)} + {stats['total']} + {stats['passed']} + {stats['failed']} + {stats.get('errors', 0)} + +
+
+ {dr:.1f}% +
+ + """ + + # Per-task rows + task_rows = "" + for ev in evaluations: + iid = ev["instance_id"] + task_data = task_map.get(iid, {}) + status_class = "pass" if ev["passed"] else "fail" + status_text = "PASS" if ev["passed"] else "FAIL" + if ev.get("error"): + status_class = "error" + status_text = "ERROR" + elapsed = task_data.get("elapsed_s", "—") + domain = task_data.get("domain", "—") + instruction = task_data.get("instruction", "")[:120] + agent_output = task_data.get("agent_output", "") + error_detail = ev.get("error", "") + stderr = task_data.get("stderr_tail", "") + + details_content = "" + if agent_output: + details_content += f"

Agent Output

{esc(agent_output[:3000])}
" + if error_detail: + details_content += f"

Evaluation Error

{esc(error_detail)}
" + if stderr: + details_content += f"

Stderr

{esc(stderr[:1000])}
" + + task_rows += f""" + + {esc(iid)} + {esc(domain)} + {status_text} + {elapsed} + {esc(instruction)} + """ + if details_content: + task_rows += f""" + +
{details_content}
+ """ + + return f""" + + + + +Spider 2.0-DBT Benchmark — Altimate Code + + + + +

Spider 2.0-DBT Benchmark Results

+

+ Model: {esc(model)} · + Generated: {esc(timestamp)} UTC · + Projected Rank: #{rank} of {len(all_entries)} +

+ + +
+
+
{pass_rate:.1f}%
+
Pass Rate
+
+
+
{passed}/{total}
+
Tasks Passed
+
+
+
{failed}
+
Failed
+
+
+
{errors}
+
Errors
+
+
+
{total_time:.0f}s
+
Total Time
+
+
+
{avg_time:.0f}s
+
Avg per Task
+
+
+ + +
+

Leaderboard Comparison

+ {leaderboard_svg} +
+ + +
+

Category Breakdown

+ + + + + + + {domain_rows} +
DomainTotalPassedFailedErrorsPass Rate
+
+ + +
+

Execution Time Distribution

+ {timing_svg} +
+ + +
+

Per-Task Results

+ +
+ + + + +
+ + + + + + + + + + + + {task_rows} +
Task IDDomainStatusTime (s)Instruction
+
+ + + + + + +""" + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate Spider 2.0-DBT benchmark report") + parser.add_argument("--evaluation", type=str, default=None, help="Path to evaluation JSON") + parser.add_argument("--output", type=str, default=None, help="Output HTML file path") + args = parser.parse_args() + + # Load evaluation + eval_path = Path(args.evaluation) if args.evaluation else find_latest_evaluation() + print(f"Loading evaluation: {eval_path}") + evaluation = json.loads(eval_path.read_text()) + + # Load benchmark results for timing/output data + benchmark = load_benchmark_results(evaluation) + if benchmark: + print(f"Loaded benchmark results: {evaluation.get('source_results', '')}") + else: + print("Warning: Source benchmark results not found; report will lack timing/output data.") + + # Generate HTML + html_content = build_html(evaluation, benchmark) + + # Write output + REPORTS_DIR.mkdir(parents=True, exist_ok=True) + timestamp = evaluation.get("timestamp", datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")) + output_path = Path(args.output) if args.output else REPORTS_DIR / f"spider2_report_{timestamp}.html" + output_path.write_text(html_content) + + print(f"Report generated: {output_path}") + print(f"Open in browser: file://{output_path.resolve()}") + + +if __name__ == "__main__": + main() diff --git a/experiments/spider2_dbt/requirements.txt b/experiments/spider2_dbt/requirements.txt new file mode 100644 index 0000000000..16416c5023 --- /dev/null +++ b/experiments/spider2_dbt/requirements.txt @@ -0,0 +1,5 @@ +duckdb>=0.10.0 +dbt-core>=1.7.0 +dbt-duckdb>=1.7.0 +gdown>=5.0.0 +pandas>=2.0.0 diff --git a/experiments/spider2_dbt/run_benchmark.py b/experiments/spider2_dbt/run_benchmark.py new file mode 100644 index 0000000000..273949fdf8 --- /dev/null +++ b/experiments/spider2_dbt/run_benchmark.py @@ -0,0 +1,416 @@ +"""Run Spider 2.0-DBT benchmark: invoke altimate-code per task. + +Usage: + python run_benchmark.py # All tasks + python run_benchmark.py --tasks 5 # First N tasks + python run_benchmark.py --tasks ga4_001 sf_002 # Specific tasks + python run_benchmark.py --no-resume # Force re-run all + python run_benchmark.py --timeout 300 # Custom timeout + python run_benchmark.py --parallel 4 # Run 4 tasks concurrently +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import subprocess +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from config import ( + ALTIMATE_CODE_BIN, + DEFAULT_MODEL, + DEFAULT_PARALLEL, + DEFAULT_TIMEOUT, + EXAMPLES_DIR, + INCREMENTAL_DIR, + RESULTS_DIR, + TASK_JSONL, + WORKSPACE_DIR, + get_task_domain, +) +from prompt_template import build_task_prompt + + +def load_tasks(task_jsonl: Path) -> list[dict[str, Any]]: + """Load tasks from the Spider2-DBT JSONL file.""" + tasks = [] + for line in task_jsonl.read_text().strip().splitlines(): + line = line.strip() + if line: + tasks.append(json.loads(line)) + return tasks + + +def filter_tasks( + tasks: list[dict[str, Any]], + task_filter: list[str] | None, +) -> list[dict[str, Any]]: + """Filter tasks by name or limit count. + + Args: + tasks: Full task list. + task_filter: Either a list of instance_ids, or a single-element list + with a number (e.g., ["5"]) to take first N tasks. 
+ """ + if not task_filter: + return tasks + + # If single numeric argument, take first N + if len(task_filter) == 1 and task_filter[0].isdigit(): + n = int(task_filter[0]) + return tasks[:n] + + # Otherwise filter by instance_id + filter_set = set(task_filter) + return [t for t in tasks if t["instance_id"] in filter_set] + + +def prepare_workspace(instance_id: str) -> Path: + """Copy dbt project from examples to workspace.""" + src = EXAMPLES_DIR / instance_id + dst = WORKSPACE_DIR / instance_id + + if dst.exists(): + shutil.rmtree(dst) + shutil.copytree(src, dst) + + return dst + + +def run_single_task( + task: dict[str, Any], + model: str, + agent: str | None, + timeout: int, +) -> dict[str, Any]: + """Run altimate-code on a single Spider2-DBT task. + + Returns: + Result dict with task metadata, exit_code, elapsed_s, agent_output. + """ + instance_id = task["instance_id"] + instruction = task.get("instruction", task.get("question", "")) + + # Prepare workspace + workspace = prepare_workspace(instance_id) + + # Build prompt + prompt = build_task_prompt( + instance_id=instance_id, + instruction=instruction, + project_dir=str(workspace), + ) + + # Output file for agent's text response + output_file = workspace / "agent_output.md" + + # Build command + cmd = [ + ALTIMATE_CODE_BIN, + "run", + prompt, + "--format", "json", + "--dir", str(workspace), + "--output", str(output_file), + "--model", model, + ] + if agent: + cmd.extend(["--agent", agent]) + + # Execute + start = time.perf_counter() + try: + proc = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + cwd=str(workspace), + ) + exit_code = proc.returncode + stdout = proc.stdout + stderr = proc.stderr + timed_out = False + except subprocess.TimeoutExpired: + exit_code = -1 + stdout = "" + stderr = f"Task timed out after {timeout}s" + timed_out = True + + elapsed_s = time.perf_counter() - start + + # Read agent output if available + agent_output = "" + if output_file.exists(): + agent_output = output_file.read_text() + + # Parse JSON events from stdout + events = [] + for line in stdout.splitlines(): + line = line.strip() + if line.startswith("{"): + try: + events.append(json.loads(line)) + except json.JSONDecodeError: + pass + + # Check if dbt run succeeded (search events for successful bash output) + dbt_success = False + for event in events: + if event.get("type") == "tool_use": + output = json.dumps(event.get("output", "")) + if "Completed successfully" in output or "Done." 
in output: + dbt_success = True + + result = { + "instance_id": instance_id, + "domain": get_task_domain(instance_id), + "instruction": instruction, + "exit_code": exit_code, + "timed_out": timed_out, + "dbt_success": dbt_success, + "elapsed_s": round(elapsed_s, 2), + "agent_output": agent_output[:5000], # Truncate for storage + "event_count": len(events), + "stderr_tail": stderr[-2000:] if stderr else "", + } + + return result + + +def _run_task_wrapper(args: tuple) -> dict[str, Any]: + """Wrapper for ProcessPoolExecutor — unpacks args tuple.""" + task, model, agent, timeout = args + return run_single_task(task, model, agent, timeout) + + +def save_incremental(instance_id: str, result: dict[str, Any]) -> None: + """Save per-task result for resumability.""" + path = INCREMENTAL_DIR / f"{instance_id}.json" + path.write_text(json.dumps(result, indent=2)) + + +def load_incremental(instance_id: str) -> dict[str, Any] | None: + """Load a previously saved incremental result.""" + path = INCREMENTAL_DIR / f"{instance_id}.json" + if path.exists(): + return json.loads(path.read_text()) + return None + + +def run_sequential( + tasks_to_run: list[dict[str, Any]], + all_tasks_count: int, + model: str, + agent: str | None, + timeout: int, + resume: bool, +) -> list[dict[str, Any]]: + """Run tasks one at a time (original behavior).""" + results = [] + skipped = 0 + + for i, task in enumerate(tasks_to_run, 1): + instance_id = task["instance_id"] + + if resume: + existing = load_incremental(instance_id) + if existing is not None: + print(f" [{i}/{len(tasks_to_run)}] {instance_id} — SKIPPED (cached)") + results.append(existing) + skipped += 1 + continue + + print(f" [{i}/{len(tasks_to_run)}] {instance_id} — running...", end="", flush=True) + + result = run_single_task(task, model, agent, timeout) + save_incremental(instance_id, result) + results.append(result) + + status = "OK" if result["exit_code"] == 0 else "FAIL" + if result["timed_out"]: + status = "TIMEOUT" + print(f" {status} ({result['elapsed_s']}s)") + + return results + + +def run_parallel( + tasks_to_run: list[dict[str, Any]], + all_tasks_count: int, + model: str, + agent: str | None, + timeout: int, + resume: bool, + workers: int, +) -> list[dict[str, Any]]: + """Run tasks concurrently using a process pool.""" + results_map: dict[str, dict[str, Any]] = {} + to_submit: list[dict[str, Any]] = [] + + # Separate cached vs need-to-run + for task in tasks_to_run: + instance_id = task["instance_id"] + if resume: + existing = load_incremental(instance_id) + if existing is not None: + print(f" {instance_id} — SKIPPED (cached)") + results_map[instance_id] = existing + continue + to_submit.append(task) + + if not to_submit: + return [results_map[t["instance_id"]] for t in tasks_to_run] + + print(f"\n Running {len(to_submit)} tasks with {workers} workers...\n") + + with ProcessPoolExecutor(max_workers=workers) as pool: + future_to_id = {} + for task in to_submit: + future = pool.submit(_run_task_wrapper, (task, model, agent, timeout)) + future_to_id[future] = task["instance_id"] + + completed = 0 + for future in as_completed(future_to_id): + instance_id = future_to_id[future] + completed += 1 + try: + result = future.result() + save_incremental(instance_id, result) + results_map[instance_id] = result + + status = "OK" if result["exit_code"] == 0 else "FAIL" + if result["timed_out"]: + status = "TIMEOUT" + print(f" [{completed}/{len(to_submit)}] {instance_id} — {status} ({result['elapsed_s']}s)") + except Exception as e: + print(f" 
[{completed}/{len(to_submit)}] {instance_id} — ERROR: {e}") + error_result = { + "instance_id": instance_id, + "domain": get_task_domain(instance_id), + "instruction": "", + "exit_code": -1, + "timed_out": False, + "dbt_success": False, + "elapsed_s": 0, + "agent_output": "", + "event_count": 0, + "stderr_tail": str(e)[:2000], + } + save_incremental(instance_id, error_result) + results_map[instance_id] = error_result + + # Return in original task order + return [results_map[t["instance_id"]] for t in tasks_to_run] + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run Spider 2.0-DBT benchmark") + parser.add_argument( + "--tasks", nargs="*", default=None, + help="Task filter: number (first N) or space-separated instance_ids", + ) + parser.add_argument("--timeout", type=int, default=DEFAULT_TIMEOUT, help="Timeout per task in seconds") + parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="Model to use") + parser.add_argument("--agent", type=str, default=None, help="Agent to use") + parser.add_argument("--no-resume", action="store_true", help="Force re-run all tasks") + parser.add_argument("--dry-run", action="store_true", help="Print tasks without running") + parser.add_argument("--parallel", type=int, default=DEFAULT_PARALLEL, help=f"Number of concurrent tasks (default: {DEFAULT_PARALLEL})") + args = parser.parse_args() + + # Load and filter tasks + if not TASK_JSONL.exists(): + print(f"ERROR: Task file not found: {TASK_JSONL}") + print("Run `python setup_spider2.py` first.") + sys.exit(1) + + all_tasks = load_tasks(TASK_JSONL) + tasks = filter_tasks(all_tasks, args.tasks) + + print("=" * 60) + print("Spider 2.0-DBT Benchmark Runner") + print("=" * 60) + print(f" Tasks: {len(tasks)} / {len(all_tasks)}") + print(f" Model: {args.model}") + print(f" Timeout: {args.timeout}s") + print(f" Resume: {'disabled' if args.no_resume else 'enabled'}") + print(f" Parallel: {args.parallel} worker{'s' if args.parallel > 1 else ''}") + print() + + if args.dry_run: + for t in tasks: + print(f" {t['instance_id']}: {t.get('instruction', t.get('question', ''))[:80]}...") + return + + # Ensure directories exist + WORKSPACE_DIR.mkdir(parents=True, exist_ok=True) + RESULTS_DIR.mkdir(parents=True, exist_ok=True) + INCREMENTAL_DIR.mkdir(parents=True, exist_ok=True) + + total_start = time.perf_counter() + resume = not args.no_resume + + if args.parallel > 1: + results = run_parallel(tasks, len(all_tasks), args.model, args.agent, args.timeout, resume, args.parallel) + else: + results = run_sequential(tasks, len(all_tasks), args.model, args.agent, args.timeout, resume) + + total_elapsed = time.perf_counter() - total_start + skipped = sum(1 for r in results if load_incremental(r["instance_id"]) is not None and r.get("_cached", False)) + + # Aggregate results + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + completed_count = sum(1 for r in results if r["exit_code"] == 0) + failed_count = sum(1 for r in results if r["exit_code"] != 0 and not r["timed_out"]) + timed_out_count = sum(1 for r in results if r["timed_out"]) + + aggregate = { + "timestamp": timestamp, + "model": args.model, + "agent": args.agent, + "timeout": args.timeout, + "parallel_workers": args.parallel, + "total_tasks": len(results), + "completed": completed_count, + "failed": failed_count, + "timed_out": timed_out_count, + "total_elapsed_s": round(total_elapsed, 2), + "avg_elapsed_s": round(sum(r["elapsed_s"] for r in results) / max(len(results), 1), 2), + "task_results": results, + } + + # Save 
aggregate + output_path = RESULTS_DIR / f"spider2_benchmark_{timestamp}.json" + output_path.write_text(json.dumps(aggregate, indent=2)) + + # Also save as "latest" symlink + latest_path = RESULTS_DIR / "latest.json" + if latest_path.is_symlink() or latest_path.exists(): + latest_path.unlink() + latest_path.symlink_to(output_path.name) + + # Print summary + print() + print("=" * 60) + print("Benchmark Complete") + print("=" * 60) + print(f" Total tasks: {aggregate['total_tasks']}") + print(f" Completed: {aggregate['completed']}") + print(f" Failed: {aggregate['failed']}") + print(f" Timed out: {aggregate['timed_out']}") + print(f" Wall time: {aggregate['total_elapsed_s']}s") + print(f" Avg per task: {aggregate['avg_elapsed_s']}s") + print(f" Results: {output_path}") + print() + print("Next: python evaluate_results.py") + + +if __name__ == "__main__": + main() diff --git a/experiments/spider2_dbt/schema_introspect.py b/experiments/spider2_dbt/schema_introspect.py new file mode 100644 index 0000000000..86daaaeba5 --- /dev/null +++ b/experiments/spider2_dbt/schema_introspect.py @@ -0,0 +1,94 @@ +"""Pre-compute DuckDB schema information for benchmark tasks. + +Queries the DuckDB database in a workspace to extract a compact +table listing (name + columns + row count). Kept concise to avoid +overwhelming the agent prompt. +""" + +from __future__ import annotations + +from pathlib import Path + + +def introspect_duckdb_schema(workspace: Path, max_tables: int = 30) -> str: + """Query DuckDB database files in workspace and return a compact schema summary. + + Produces a ~2-4KB summary listing tables with their columns (name + type) + and row counts. No sample data — keeps the prompt focused. + + Args: + workspace: Path to the dbt project workspace directory. + max_tables: Maximum number of tables to include. + + Returns: + A formatted string with schema information, or empty string if no DB found. + """ + try: + import duckdb + except ImportError: + return "" + + # Find DuckDB files + db_files = list(workspace.glob("*.duckdb")) + list(workspace.glob("*.db")) + if not db_files: + db_files = [ + f for f in workspace.rglob("*.duckdb") + if "target" not in str(f) and ".dbt" not in str(f) + ] + if not db_files: + return "" + + db_path = db_files[0] + + try: + conn = duckdb.connect(str(db_path), read_only=True) + except Exception: + return "" + + try: + tables = conn.execute( + "SELECT table_schema, table_name FROM information_schema.tables " + "WHERE table_schema NOT IN ('information_schema', 'pg_catalog') " + "ORDER BY table_schema, table_name" + ).fetchall() + + if not tables: + conn.close() + return "" + + lines = [f"## Source Database: `{db_path.name}` ({len(tables)} tables)\n"] + + for schema, table in tables[:max_tables]: + full_name = f"{schema}.{table}" if schema != "main" else table + + # Get columns (name + type only) + cols = conn.execute( + "SELECT column_name, data_type FROM information_schema.columns " + f"WHERE table_schema = '{schema}' AND table_name = '{table}' " + "ORDER BY ordinal_position" + ).fetchall() + + # Get row count + try: + row_count = conn.execute( + f'SELECT COUNT(*) FROM "{schema}"."{table}"' + ).fetchone()[0] + except Exception: + row_count = "?" + + col_summary = ", ".join(f"{c[0]} ({c[1]})" for c in cols) + lines.append(f"- **{full_name}** ({row_count} rows): {col_summary}") + + result = "\n".join(lines) + + # Hard cap at 5000 chars to avoid overwhelming the prompt + if len(result) > 5000: + # Truncate and add note + result = result[:4900] + "\n\n... 
(truncated — query the database for full schema)" + + return result + + except Exception: + return "" + finally: + conn.close() diff --git a/experiments/spider2_dbt/setup_spider2.py b/experiments/spider2_dbt/setup_spider2.py new file mode 100644 index 0000000000..f74744d526 --- /dev/null +++ b/experiments/spider2_dbt/setup_spider2.py @@ -0,0 +1,242 @@ +"""One-time setup: clone Spider2 repo, download DuckDB databases, verify deps. + +Usage: + python setup_spider2.py [--force] +""" + +from __future__ import annotations + +import argparse +import os +import shutil +import subprocess +import sys +from pathlib import Path + +from config import ( + ALTIMATE_CODE_BIN, + BASE_DIR, + DUCKDB_ZIP_DOWNLOADS, + EXAMPLES_DIR, + INCREMENTAL_DIR, + REPORTS_DIR, + RESULTS_DIR, + SPIDER2_COMMIT, + SPIDER2_DBT_DIR, + SPIDER2_REPO_DIR, + SPIDER2_REPO_URL, + TASK_JSONL, + WORKSPACE_DIR, +) + + +def run_cmd(cmd: list[str], cwd: str | None = None, check: bool = True) -> subprocess.CompletedProcess: + """Run a shell command with logging.""" + print(f" $ {' '.join(cmd)}") + return subprocess.run(cmd, cwd=cwd, check=check, capture_output=False) + + +def clone_spider2(force: bool = False) -> None: + """Sparse-clone Spider2 repo (only spider2-dbt/ directory).""" + if SPIDER2_REPO_DIR.exists(): + if force: + print(f"Removing existing repo at {SPIDER2_REPO_DIR}...") + shutil.rmtree(SPIDER2_REPO_DIR) + else: + print(f"Spider2 repo already exists at {SPIDER2_REPO_DIR}. Use --force to re-clone.") + return + + print("Cloning Spider2 repository (sparse, spider2-dbt/ only)...") + SPIDER2_REPO_DIR.mkdir(parents=True, exist_ok=True) + + run_cmd(["git", "init"], cwd=str(SPIDER2_REPO_DIR)) + run_cmd(["git", "remote", "add", "origin", SPIDER2_REPO_URL], cwd=str(SPIDER2_REPO_DIR)) + run_cmd(["git", "config", "core.sparseCheckout", "true"], cwd=str(SPIDER2_REPO_DIR)) + + sparse_file = SPIDER2_REPO_DIR / ".git" / "info" / "sparse-checkout" + sparse_file.parent.mkdir(parents=True, exist_ok=True) + sparse_file.write_text("spider2-dbt/\n") + + run_cmd(["git", "fetch", "--depth", "1", "origin", SPIDER2_COMMIT], cwd=str(SPIDER2_REPO_DIR)) + run_cmd(["git", "checkout", "FETCH_HEAD"], cwd=str(SPIDER2_REPO_DIR)) + + if not SPIDER2_DBT_DIR.exists(): + print("ERROR: spider2-dbt/ directory not found after clone.") + sys.exit(1) + + print(f"Spider2 repo cloned to {SPIDER2_REPO_DIR}") + + +def download_databases() -> None: + """Download DuckDB database zips from Google Drive using gdown. + + Spider2 expects two zips in the spider2-dbt/ directory: + - DBT_start_db.zip (example project databases) + - dbt_gold.zip (gold standard evaluation databases) + """ + # Check if zips already exist + all_present = all( + (SPIDER2_DBT_DIR / filename).exists() + for _, filename in DUCKDB_ZIP_DOWNLOADS + ) + if all_present: + print("Database zips already present. 
Skipping download.") + return + + print("Downloading DuckDB databases from Google Drive...") + failed = [] + + for gdrive_id, filename in DUCKDB_ZIP_DOWNLOADS: + output = SPIDER2_DBT_DIR / filename + if output.exists(): + print(f" {filename} already exists, skipping.") + continue + + url = f"https://drive.google.com/uc?id={gdrive_id}" + result = run_cmd(["gdown", url, "-O", str(output)], check=False) + if result.returncode != 0 or not output.exists(): + failed.append(filename) + + if failed: + print("\nWARNING: Failed to download some files via gdown.") + print("This often happens due to Google Drive rate limits.") + print("Please download manually and place in:") + print(f" {SPIDER2_DBT_DIR}/") + print() + for _, filename in DUCKDB_ZIP_DOWNLOADS: + if filename in failed: + gdrive_id = next(gid for gid, fn in DUCKDB_ZIP_DOWNLOADS if fn == filename) + print(f" {filename}:") + print(f" https://drive.google.com/uc?id={gdrive_id}") + print() + print("Then re-run: python setup_spider2.py --skip-download") + sys.exit(1) + + +def run_spider2_setup() -> None: + """Run Spider2's own setup.py to extract databases into examples/ and gold/.""" + # Check if zips exist first + for _, filename in DUCKDB_ZIP_DOWNLOADS: + zip_path = SPIDER2_DBT_DIR / filename + if not zip_path.exists(): + print(f"WARNING: {filename} not found, skipping Spider2 setup.") + print("Run download step first or place files manually.") + return + + setup_script = SPIDER2_DBT_DIR / "setup.py" + if setup_script.exists(): + print("Running Spider2's setup.py to extract databases...") + run_cmd([sys.executable, str(setup_script)], cwd=str(SPIDER2_DBT_DIR)) + else: + print("No Spider2 setup.py found; skipping.") + + +def verify_dependencies() -> None: + """Verify all required tools are available.""" + print("\nVerifying dependencies...") + errors = [] + + # Python packages + for pkg_name, import_name in [("duckdb", "duckdb"), ("dbt-core", "dbt"), ("pandas", "pandas")]: + try: + __import__(import_name) + print(f" {pkg_name}: OK") + except ImportError: + errors.append(f" Missing Python package: {pkg_name} (pip install {pkg_name})") + + # dbt-duckdb adapter + try: + result = subprocess.run( + ["dbt", "--version"], capture_output=True, text=True, check=False + ) + if result.returncode == 0: + version_lines = result.stdout.strip().splitlines() + for line in version_lines: + if "duckdb" in line.lower(): + print(f" dbt-duckdb: {line.strip()}") + break + else: + print(" Warning: dbt-duckdb adapter may not be installed.") + else: + errors.append(" dbt CLI returned error") + except FileNotFoundError: + errors.append(" dbt CLI not found (pip install dbt-core dbt-duckdb)") + + # altimate-code + result = subprocess.run( + [ALTIMATE_CODE_BIN, "--version"], capture_output=True, text=True, check=False + ) + if result.returncode != 0: + errors.append(f" altimate-code CLI not found at: {ALTIMATE_CODE_BIN}") + else: + print(f" altimate-code: {result.stdout.strip()}") + + # Task file + task_jsonl = SPIDER2_DBT_DIR / "examples" / "spider2-dbt.jsonl" + if not task_jsonl.exists(): + # Try alternative name + task_jsonl = TASK_JSONL + if not task_jsonl.exists(): + errors.append(f" Task file not found: {task_jsonl}") + else: + import json + tasks = [json.loads(line) for line in task_jsonl.read_text().strip().splitlines()] + print(f" Tasks found: {len(tasks)}") + + # Examples directory + if not EXAMPLES_DIR.exists(): + errors.append(f" Examples directory not found: {EXAMPLES_DIR}") + else: + examples = [d for d in EXAMPLES_DIR.iterdir() if d.is_dir()] + print(f" 
Example projects: {len(examples)}") + + # Check for DuckDB files in examples (indicates setup.py ran) + duckdb_count = sum(1 for _ in EXAMPLES_DIR.rglob("*.duckdb")) if EXAMPLES_DIR.exists() else 0 + print(f" DuckDB files in examples: {duckdb_count}") + if duckdb_count == 0: + print(" Warning: No .duckdb files found — databases may not be extracted yet.") + + if errors: + print("\nERRORS:") + for err in errors: + print(err) + sys.exit(1) + + print("\nAll dependencies verified.") + + +def create_directories() -> None: + """Create workspace and results directories.""" + for d in [WORKSPACE_DIR, RESULTS_DIR, INCREMENTAL_DIR, REPORTS_DIR]: + d.mkdir(parents=True, exist_ok=True) + print("Directories created.") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Set up Spider 2.0-DBT benchmark environment") + parser.add_argument("--force", action="store_true", help="Force re-clone of Spider2 repo") + parser.add_argument("--skip-download", action="store_true", help="Skip database download") + args = parser.parse_args() + + print("=" * 60) + print("Spider 2.0-DBT Benchmark Setup") + print("=" * 60) + + clone_spider2(force=args.force) + + if not args.skip_download: + download_databases() + + run_spider2_setup() + create_directories() + verify_dependencies() + + print("\n" + "=" * 60) + print("Setup complete! Next steps:") + print(" python run_benchmark.py # Run benchmark") + print(" python run_benchmark.py --tasks 5 # Smoke test (first 5)") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/packages/opencode/src/altimate/prompts/builder.txt b/packages/opencode/src/altimate/prompts/builder.txt index f70a79464f..9c74e8138e 100644 --- a/packages/opencode/src/altimate/prompts/builder.txt +++ b/packages/opencode/src/altimate/prompts/builder.txt @@ -13,23 +13,27 @@ You have full read/write access to the project. You can: - Use all standard file tools (read, write, edit, bash, grep, glob) When writing SQL: -- Always run `sql_analyze` to check for anti-patterns before finalizing queries -- Validate SQL with `sql_validate` before executing against a warehouse -- Use `schema_inspect` to understand table structures before writing queries +- Run `sql_analyze` to check for anti-patterns before finalizing queries (skip if unavailable) +- Validate SQL with `sql_validate` before executing against a warehouse (skip if unavailable) +- Use `schema_inspect` to understand table structures before writing queries (skip if unavailable) - Prefer CTEs over subqueries for readability - Include column descriptions in dbt YAML files +- **Avoid non-deterministic functions** (`current_date`, `now()`, `current_timestamp`, `getdate()`) in models operating on fixed/historical datasets. Use date columns from source data instead. Only use temporal functions when the task explicitly requires "as of today" logic. Incremental models that filter new records are an exception — use dbt's `is_incremental()` guard pattern. +- **Read before writing**: Always read existing model files thoroughly before creating new ones. Understand the existing schema, column names, and data flow. Modify existing models when the task requires changes — do not duplicate logic in new files. 
+- **Quote reserved words**: If a column name is a SQL reserved word (e.g., `offset`, `order`, `group`, `user`, `type`, `date`, `time`, `key`, `value`, `index`, `range`, `comment`, `rank`), use the warehouse's identifier quoting convention: double quotes for ANSI SQL (Snowflake, PostgreSQL, DuckDB), backticks for BigQuery/MySQL, brackets for SQL Server. When creating dbt models: +- **Explore first**: Read relevant existing models in the same layer/domain before writing anything. Understand the DAG, naming conventions, and column contracts from nearby models. - Follow the project's existing naming conventions - Place staging models in staging/, intermediate in intermediate/, marts in marts/ - Use `/generate-tests` to auto-generate test definitions - Add tests for primary keys and not-null constraints - Update schema.yml files alongside model changes -- Run `lineage_check` to verify column-level data flow +- Run `lineage_check` to verify column-level data flow (skip if unavailable) ## Pre-Execution Protocol -Before executing ANY SQL via sql_execute, follow this mandatory sequence: +Before executing ANY SQL via sql_execute, follow this sequence: 1. **Analyze first**: Run sql_analyze on the query. Check for HIGH severity anti-patterns. - If HIGH severity issues found (SELECT *, cartesian products, missing WHERE on DELETE/UPDATE, full table scans on large tables): FIX THEM before executing. Show the user what you found and the fixed query. @@ -39,19 +43,29 @@ Before executing ANY SQL via sql_execute, follow this mandatory sequence: 3. **Execute**: Only after steps 1-2 pass, run sql_execute. -This sequence is NOT optional. Skipping it means the user pays for avoidable mistakes. You are the customer's cost advocate — every credit saved is trust earned. If the user explicitly requests skipping the protocol, note the risk and proceed. +You are the customer's cost advocate — every credit saved is trust earned. If the user explicitly requests skipping the protocol, note the risk and proceed. For trivial queries (e.g., `SELECT 1`, `SHOW TABLES`), use judgment — skip the full sequence but still validate syntax. +**Graceful degradation**: If sql_analyze or sql_validate are unavailable (e.g., no warehouse connection, tool errors, local-only project), proceed with the next available step. Do not get stuck retrying failed tools — adapt and use alternative validation (e.g., manually review the SQL, run a dry-run query, or use dbt compile). The goal is validation, not strict tool adherence. + ## dbt Verification Workflow After ANY dbt operation (build, run, test, model creation/modification): 1. **Compile check**: Verify the model compiles without errors -2. **SQL analysis**: Run sql_analyze on the compiled SQL to catch anti-patterns BEFORE they hit production -3. **Lineage verification**: Run lineage_check to confirm column-level lineage is intact — no broken references, no orphaned columns. If lineage_check fails (e.g., no manifest available), note the limitation and proceed. +2. **SQL analysis**: Run sql_analyze on the compiled SQL to catch anti-patterns (skip if unavailable — use manual review instead) +3. **Lineage verification**: Run lineage_check to confirm column-level lineage is intact (skip if unavailable — verify `ref()` and `source()` references manually) 4. **Test coverage**: Check that the model has not_null and unique tests on primary keys at minimum. If missing, suggest adding them. -Do NOT consider a dbt task complete until steps 1-4 pass. 
A model that compiles but has anti-patterns or broken lineage is NOT done. +5. **Output validation**: After a successful dbt run, **query the output database directly** to verify the results are reasonable: + - Check row counts (`SELECT COUNT(*) FROM table`) + - Sample a few rows (`SELECT * FROM table LIMIT 5`) + - Verify column names and types match expectations + - Cross-check aggregations against source data (e.g., does the sum match?) + - For DuckDB projects, use `duckdb -c ""` via bash + - For other warehouses, use `sql_execute` with verification queries + +Do NOT consider a dbt task complete until you have verified the output data is correct, not just that the SQL compiles. A model that runs but produces wrong results is NOT done. ## Self-Review Before Completion @@ -62,12 +76,17 @@ Before declaring any task complete, review your own work: - Missing edge cases (NULLs, empty strings, zero-division) - Naming convention violations (check project's existing patterns) - Unnecessary complexity (could a CTE be a subquery? could a join be avoided?) + - **Non-deterministic functions**: Verify no `current_date`, `now()`, `current_timestamp`, or `getdate()` usage unless explicitly required by the task + - **JOIN correctness**: Verify JOIN types (INNER vs LEFT vs FULL) match the business requirement — don't default to LEFT JOIN when INNER is appropriate + - **Aggregation completeness**: Verify GROUP BY includes all required dimensions and all requested metrics are computed + +2. **Validate the output**: Run sql_validate and sql_analyze on any SQL you wrote (skip if unavailable). -2. **Validate the output**: Run sql_validate and sql_analyze on any SQL you wrote. +3. **Query the actual output**: Run a quick query against the output table/view to verify the data looks correct. Check row counts, column names, sample values. -3. **Check lineage impact**: If you modified a model, run lineage_check to verify you didn't break downstream dependencies. +4. **Check lineage impact**: If you modified a model, run lineage_check to verify you didn't break downstream dependencies (skip if unavailable). -Only after self-review passes should you present the result to the user. +Only after self-review passes — including output data verification — should you present the result to the user. ## Available Skills You have access to these skills that users can invoke with /: diff --git a/packages/opencode/src/session/message-v2.ts b/packages/opencode/src/session/message-v2.ts index 5b4e7bdbc0..4d5b515937 100644 --- a/packages/opencode/src/session/message-v2.ts +++ b/packages/opencode/src/session/message-v2.ts @@ -634,7 +634,9 @@ export namespace MessageV2 { if (part.type === "tool") { toolNames.add(part.tool) if (part.state.status === "completed") { - const outputText = part.state.time.compacted ? "[Old tool result content cleared]" : part.state.output + const outputText = part.state.time.compacted + ? (part.state.metadata?.observation_mask ?? "[Old tool result content cleared]") + : part.state.output const attachments = part.state.time.compacted || options?.stripMedia ? [] : (part.state.attachments ?? 
[]) // For providers that don't support media in tool results, extract media files diff --git a/packages/opencode/src/session/system.ts b/packages/opencode/src/session/system.ts index a61dd8cba5..32554562a9 100644 --- a/packages/opencode/src/session/system.ts +++ b/packages/opencode/src/session/system.ts @@ -1,5 +1,3 @@ -import { Ripgrep } from "../file/ripgrep" - import { Instance } from "../project/instance" import PROMPT_ANTHROPIC from "./prompt/anthropic.txt" @@ -38,16 +36,6 @@ export namespace SystemPrompt { ` Platform: ${process.platform}`, ` Today's date: ${new Date().toDateString()}`, ``, - ``, - ` ${ - project.vcs === "git" && false - ? await Ripgrep.tree({ - cwd: Instance.directory, - limit: 50, - }) - : "" - }`, - ``, ].join("\n"), ] } diff --git a/packages/opencode/src/tool/skill.ts b/packages/opencode/src/tool/skill.ts index 8fcfb592de..0ae87316be 100644 --- a/packages/opencode/src/tool/skill.ts +++ b/packages/opencode/src/tool/skill.ts @@ -35,13 +35,9 @@ export const SkillTool = Tool.define("skill", async (ctx) => { "Invoke this tool to load a skill when a task matches one of the available skills listed below:", "", "", - ...accessibleSkills.flatMap((skill) => [ - ` `, - ` ${skill.name}`, - ` ${skill.description}`, - ` ${pathToFileURL(skill.location).href}`, - ` `, - ]), + ...accessibleSkills.map( + (skill) => ` ${skill.description}`, + ), "", ].join("\n") diff --git a/packages/opencode/test/session/context-efficiency.test.ts b/packages/opencode/test/session/context-efficiency.test.ts new file mode 100644 index 0000000000..8de10c6773 --- /dev/null +++ b/packages/opencode/test/session/context-efficiency.test.ts @@ -0,0 +1,543 @@ +import { describe, expect, test } from "bun:test" +import path from "path" +import { SessionCompaction } from "../../src/session/compaction" +import { MessageV2 } from "../../src/session/message-v2" +import { Session } from "../../src/session" +import { Identifier } from "../../src/id/id" +import { Instance } from "../../src/project/instance" +import { SystemPrompt } from "../../src/session/system" +import { SkillTool } from "../../src/tool/skill" +import { Log } from "../../src/util/log" +import { tmpdir } from "../fixture/fixture" +import type { Provider } from "../../src/provider/provider" + +Log.init({ print: false }) + +const model: Provider.Model = { + id: "test-model", + providerID: "test", + api: { + id: "test-model", + url: "https://example.com", + npm: "@ai-sdk/openai", + }, + name: "Test Model", + capabilities: { + temperature: true, + reasoning: false, + attachment: false, + toolcall: true, + input: { text: true, audio: false, image: false, video: false, pdf: false }, + output: { text: true, audio: false, image: false, video: false, pdf: false }, + interleaved: false, + }, + cost: { input: 0, output: 0, cache: { read: 0, write: 0 } }, + limit: { context: 0, input: 0, output: 0 }, + status: "active", + options: {}, + headers: {}, + release_date: "2026-01-01", +} as Provider.Model + +// ── Observation mask: toModelMessages ────────────────────────────────────── + +describe("observation mask in toModelMessages", () => { + const sessionID = "session" + + function userInfo(id: string): MessageV2.User { + return { + id, + sessionID, + role: "user", + time: { created: 0 }, + agent: "user", + model: { providerID: "test", modelID: "test" }, + tools: {}, + mode: "", + } as unknown as MessageV2.User + } + + function assistantInfo(id: string, parentID: string): MessageV2.Assistant { + return { + id, + sessionID, + role: "assistant", + time: { created: 
0 }, + parentID, + modelID: model.api.id, + providerID: model.providerID, + mode: "", + agent: "agent", + path: { cwd: "/", root: "/" }, + cost: 0, + tokens: { input: 0, output: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + } as unknown as MessageV2.Assistant + } + + function basePart(messageID: string, id: string) { + return { id, sessionID, messageID } + } + + test("uses observation_mask when tool output is compacted and mask is present", () => { + const mask = '[Tool output cleared — bash(cmd: "ls /tmp") returned 5 lines, 42 B — "file1.txt"]' + const input: MessageV2.WithParts[] = [ + { + info: userInfo("m-user"), + parts: [ + { ...basePart("m-user", "u1"), type: "text", text: "run tool" }, + ] as MessageV2.Part[], + }, + { + info: assistantInfo("m-assistant", "m-user"), + parts: [ + { + ...basePart("m-assistant", "a1"), + type: "tool", + callID: "call-1", + tool: "bash", + state: { + status: "completed", + input: { cmd: "ls /tmp" }, + output: "original output that was pruned", + title: "Bash", + metadata: { observation_mask: mask }, + time: { start: 0, end: 1, compacted: 1 }, + }, + }, + ] as MessageV2.Part[], + }, + ] + + const result = MessageV2.toModelMessages(input, model) + const toolMsg = result.find((m) => m.role === "tool") as any + expect(toolMsg).toBeDefined() + const toolResult = toolMsg.content[0] + expect(toolResult.output).toEqual({ type: "text", value: mask }) + }) + + test("falls back to generic placeholder when compacted but no observation_mask", () => { + const input: MessageV2.WithParts[] = [ + { + info: userInfo("m-user"), + parts: [ + { ...basePart("m-user", "u1"), type: "text", text: "run tool" }, + ] as MessageV2.Part[], + }, + { + info: assistantInfo("m-assistant", "m-user"), + parts: [ + { + ...basePart("m-assistant", "a1"), + type: "tool", + callID: "call-1", + tool: "bash", + state: { + status: "completed", + input: { cmd: "ls" }, + output: "this should be cleared", + title: "Bash", + metadata: {}, + time: { start: 0, end: 1, compacted: 1 }, + }, + }, + ] as MessageV2.Part[], + }, + ] + + const result = MessageV2.toModelMessages(input, model) + const toolMsg = result.find((m) => m.role === "tool") as any + const toolResult = toolMsg.content[0] + expect(toolResult.output).toEqual({ type: "text", value: "[Old tool result content cleared]" }) + }) +}) + +// ── Observation mask: prune e2e ──────────────────────────────────────────── + +describe("prune sets observation_mask and toModelMessages surfaces it", () => { + test("end-to-end: prune → observation_mask stored → toModelMessages uses it", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const session = await Session.create({}) + const sessionID = session.id + + // ── Turn 1 (old): user message ── + const userMsg1 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", + sessionID, + agent: "default", + model: { providerID: "test", modelID: "test" }, + time: { created: Date.now() }, + }) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: userMsg1.id, + sessionID, + type: "text", + text: "Read a big file", + }) + + // ── Turn 1 (old): assistant with large tool output ── + const assistantMsg1 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { cwd: tmp.path, root: tmp.path }, + cost: 0, + tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + 
modelID: "test", + providerID: "test", + parentID: userMsg1.id, + time: { created: Date.now() }, + finish: "end_turn", + } as MessageV2.Assistant) + + // Large tool output — must exceed PRUNE_PROTECT (40k tokens ~ 150k chars) + const largeOutput = "SELECT col FROM table;\n".repeat(10_000) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMsg1.id, + sessionID, + type: "tool", + callID: "call-big", + tool: "read", + state: { + status: "completed", + input: { file_path: "/models/stg_orders.sql" }, + output: largeOutput, + title: "Read", + metadata: {}, + time: { start: 0, end: 1 }, + }, + } as any) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMsg1.id, + sessionID, + type: "text", + text: "I read the file.", + }) + + // ── Turn 2 (middle): user + assistant ── + const userMsg2 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", + sessionID, + agent: "default", + model: { providerID: "test", modelID: "test" }, + time: { created: Date.now() }, + }) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: userMsg2.id, + sessionID, + type: "text", + text: "Now do something else", + }) + + const assistantMsg2 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { cwd: tmp.path, root: tmp.path }, + cost: 0, + tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + modelID: "test", + providerID: "test", + parentID: userMsg2.id, + time: { created: Date.now() }, + finish: "end_turn", + } as MessageV2.Assistant) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMsg2.id, + sessionID, + type: "text", + text: "OK, let me continue.", + }) + + // ── Turn 3 (recent): user + assistant with small tool ── + // Prune skips the 2 most recent turns, so we need 3 turns total + const userMsg3 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "user", + sessionID, + agent: "default", + model: { providerID: "test", modelID: "test" }, + time: { created: Date.now() }, + }) + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: userMsg3.id, + sessionID, + type: "text", + text: "One more thing", + }) + + const assistantMsg3 = await Session.updateMessage({ + id: Identifier.ascending("message"), + role: "assistant", + sessionID, + mode: "default", + agent: "default", + path: { cwd: tmp.path, root: tmp.path }, + cost: 0, + tokens: { output: 0, input: 0, reasoning: 0, cache: { read: 0, write: 0 } }, + modelID: "test", + providerID: "test", + parentID: userMsg3.id, + time: { created: Date.now() }, + finish: "end_turn", + } as MessageV2.Assistant) + + // Small recent tool output — should be protected from pruning + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMsg3.id, + sessionID, + type: "tool", + callID: "call-small", + tool: "bash", + state: { + status: "completed", + input: { command: "echo hi" }, + output: "hi", + title: "Bash", + metadata: {}, + time: { start: 0, end: 1 }, + }, + } as any) + + await Session.updatePart({ + id: Identifier.ascending("part"), + messageID: assistantMsg3.id, + sessionID, + type: "text", + text: "Done.", + }) + + // ── Run prune ── + await SessionCompaction.prune({ sessionID }) + + // ── Verify the old tool part was pruned with observation_mask ── + const msgs = await Session.messages({ sessionID }) + const allParts 
= msgs.flatMap((m) => m.parts) + const prunedPart = allParts.find( + (p) => p.type === "tool" && (p as any).callID === "call-big", + ) as MessageV2.ToolPart + + expect(prunedPart).toBeDefined() + expect(prunedPart.state.status).toBe("completed") + if (prunedPart.state.status !== "completed") throw new Error("unreachable") + + expect(prunedPart.state.time.compacted).toBeDefined() + expect(prunedPart.state.time.compacted).toBeGreaterThan(0) + + const mask = prunedPart.state.metadata?.observation_mask as string + expect(mask).toBeDefined() + expect(mask).toContain("[Tool output cleared") + expect(mask).toContain("read") + expect(mask).toContain("lines") + + // ── Verify toModelMessages surfaces the mask, not the fallback ── + const modelMsgs = MessageV2.toModelMessages(msgs, model) + const toolResults = modelMsgs + .filter((m) => m.role === "tool") + .flatMap((m) => (m as any).content) + .filter((c: any) => c.type === "tool-result") + + const bigToolResult = toolResults.find((c: any) => c.toolCallId === "call-big") + expect(bigToolResult).toBeDefined() + + const outputValue = + typeof bigToolResult.output === "string" + ? bigToolResult.output + : bigToolResult.output?.value ?? bigToolResult.output + expect(outputValue).toContain("[Tool output cleared") + expect(outputValue).not.toBe("[Old tool result content cleared]") + + // Recent tool should NOT be pruned + const smallToolResult = toolResults.find((c: any) => c.toolCallId === "call-small") + expect(smallToolResult).toBeDefined() + const smallOutput = + typeof smallToolResult.output === "string" + ? smallToolResult.output + : smallToolResult.output?.value ?? smallToolResult.output + expect(smallOutput).toBe("hi") + }, + }) + }) +}) + +// ── System prompt: no block ────────────────────────────────── + +describe("system prompt does not contain directories block", () => { + test("environment() output has no tag", async () => { + await using tmp = await tmpdir({ git: true }) + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const parts = await SystemPrompt.environment(model) + const joined = parts.join("\n") + expect(joined).not.toContain("") + expect(joined).not.toContain("") + expect(joined).toContain("") + expect(joined).toContain("Working directory:") + expect(joined).toContain("Today's date:") + }, + }) + }) +}) + +// ── Skill tool: compact XML format ───────────────────────────────────────── + +describe("skill tool uses compact XML format", () => { + test("description uses single-line format without ", async () => { + await using tmp = await tmpdir({ + git: true, + init: async (dir) => { + const skillDir = path.join(dir, ".opencode", "skill", "test-skill") + await Bun.write( + path.join(skillDir, "SKILL.md"), + `--- +name: test-skill +description: A test skill for validation. 
+--- + +# Test Skill +`, + ) + }, + }) + + const home = process.env.OPENCODE_TEST_HOME + process.env.OPENCODE_TEST_HOME = tmp.path + + try { + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await SkillTool.init() + + // Compact format: single-line with name attribute + expect(tool.description).toContain( + 'A test skill for validation.', + ) + + // Old verbose format should NOT be present + expect(tool.description).not.toContain("test-skill") + expect(tool.description).not.toContain("") + expect(tool.description).not.toContain("") + }, + }) + } finally { + process.env.OPENCODE_TEST_HOME = home + } + }) + + test("multiple skills use compact format with no location URLs", async () => { + await using tmp = await tmpdir({ + git: true, + init: async (dir) => { + for (const [name, desc] of [ + ["skill-alpha", "First skill"], + ["skill-beta", "Second skill"], + ["skill-gamma", "Third skill"], + ]) { + const skillDir = path.join(dir, ".opencode", "skill", name) + await Bun.write( + path.join(skillDir, "SKILL.md"), + `---\nname: ${name}\ndescription: ${desc}\n---\n\n# ${name}\n`, + ) + } + }, + }) + + const home = process.env.OPENCODE_TEST_HOME + process.env.OPENCODE_TEST_HOME = tmp.path + + try { + await Instance.provide({ + directory: tmp.path, + fn: async () => { + const tool = await SkillTool.init() + + expect(tool.description).toContain('First skill') + expect(tool.description).toContain('Second skill') + expect(tool.description).toContain('Third skill') + + // No multi-line skill blocks + expect(tool.description).not.toMatch(/\s*\n/) + // No location tags + expect(tool.description).not.toContain("") + }, + }) + } finally { + process.env.OPENCODE_TEST_HOME = home + } + }) +}) + +// ── createObservationMask ────────────────────────────────────────────────── + +describe("createObservationMask", () => { + test("produces mask with tool name, args, line count, byte size, and preview", () => { + const part = { + id: "part-1", + sessionID: "session-1", + messageID: "msg-1", + type: "tool" as const, + tool: "read", + callID: "call-1", + state: { + status: "completed" as const, + input: { file_path: "/models/stg_orders.sql" }, + output: "SELECT order_id, customer_id\nFROM raw.orders\nWHERE status = 'completed'", + title: "Read", + metadata: {}, + time: { start: 0, end: 1 }, + }, + } as unknown as MessageV2.ToolPart + + const mask = SessionCompaction.createObservationMask(part) + + expect(mask).toContain("[Tool output cleared") + expect(mask).toContain("read") + expect(mask).toContain("file_path") + expect(mask).toContain("3 lines") + expect(mask).toContain("SELECT order_id") + }) + + test("handles empty output gracefully", () => { + const part = { + id: "part-2", + sessionID: "session-1", + messageID: "msg-2", + type: "tool" as const, + tool: "bash", + callID: "call-2", + state: { + status: "completed" as const, + input: { command: "true" }, + output: "", + title: "Bash", + metadata: {}, + time: { start: 0, end: 1 }, + }, + } as unknown as MessageV2.ToolPart + + const mask = SessionCompaction.createObservationMask(part) + expect(mask).toContain("[Tool output cleared") + expect(mask).toContain("bash") + expect(mask).toContain("1 lines") + expect(mask).toContain("0 B") + }) +}) diff --git a/packages/opencode/test/tool/skill.test.ts b/packages/opencode/test/tool/skill.test.ts index d5057ba9e7..3091f6670a 100644 --- a/packages/opencode/test/tool/skill.test.ts +++ b/packages/opencode/test/tool/skill.test.ts @@ -18,7 +18,7 @@ const baseCtx: Omit = { } 
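The createObservationMask tests above pin down the mask shape: a `[Tool output cleared` prefix, the tool name and its input arguments, a line count, a byte size, and a short preview of the first line of output. A minimal sketch consistent with those expectations follows; the names and exact formatting are assumptions for illustration, and the real `SessionCompaction.createObservationMask` is not shown in this diff.

```ts
// Sketch only: mirrors what the tests above assert, not the actual SessionCompaction code.
type CompletedToolPart = {
  tool: string
  state: { status: "completed"; input: Record<string, unknown>; output: string }
}

function createObservationMaskSketch(part: CompletedToolPart): string {
  // Render the tool input as `key: "value"` pairs, e.g. cmd: "ls /tmp"
  const args = Object.entries(part.state.input)
    .map(([key, value]) => `${key}: ${JSON.stringify(value)}`)
    .join(", ")
  const output = part.state.output
  const lines = output.split("\n").length // "" still counts as 1 line, matching the empty-output test
  const bytes = new TextEncoder().encode(output).byteLength // 0 B for empty output
  const preview = JSON.stringify(output.split("\n")[0]?.slice(0, 80) ?? "")
  return `[Tool output cleared — ${part.tool}(${args}) returned ${lines} lines, ${bytes} B — ${preview}]`
}
```

Under this shape, a pruned `read` of a large SQL file still tells the model which file was read, how large the dropped output was, and what its first line looked like.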
describe("tool.skill", () => { - test("description lists skill location URL", async () => { + test("description lists skill name and description", async () => { await using tmp = await tmpdir({ git: true, init: async (dir) => { @@ -44,8 +44,7 @@ description: Skill for tool tests. directory: tmp.path, fn: async () => { const tool = await SkillTool.init() - const skillPath = path.join(tmp.path, ".opencode", "skill", "tool-skill", "SKILL.md") - expect(tool.description).toContain(`${pathToFileURL(skillPath).href}`) + expect(tool.description).toContain(`Skill for tool tests.`) }, }) } finally {