|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Run search-only sweeps on an existing ArcadeDB MSMARCO database to compare |
| 4 | +different `overquery_factor` values without rebuilding the index. |
| 5 | +
|
| 6 | +For each factor, we: |
| 7 | +- load up to 1,000 query vectors and their ground-truth labels |
| 8 | +- open the existing DB |
| 9 | +- warm up once |
| 10 | +- run a full search pass (recall + latency) |
| 11 | +- write results.json / results.md under an output directory that the |
| 12 | + existing summarizer can consume. |
| 13 | +
|
| 14 | +Outputs are written to `arcadedb_runs/*/results.json` by default so that |
| 15 | +`summarize_arcadedb_msmarco.py` will include them in its markdown tables. |
| 16 | +""" |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import argparse |
| 20 | +import json |
| 21 | +from pathlib import Path |
| 22 | +from typing import Dict, List |
| 23 | + |
| 24 | +import arcadedb_embedded as arcadedb |
| 25 | +import numpy as np |
| 26 | +from benchmark_arcadedb_msmarco import ( |
| 27 | + dir_size_mb, |
| 28 | + load_ground_truth, |
| 29 | + load_queries, |
| 30 | + materialize_queries, |
| 31 | + resolve_dataset, |
| 32 | + rss_mb, |
| 33 | + search_index, |
| 34 | + timed_section, |
| 35 | + warmup, |
| 36 | +) |
| 37 | + |
| 38 | + |
def parse_overqueries(raw: str) -> List[int]:
    """Parse a comma-separated list of overquery factors.

    Blank entries (for example from a trailing comma) are ignored. Repeated
    values are collapsed to their first occurrence: the output directory name
    is derived from the factor, so running the same factor twice would only
    repeat the work and overwrite the earlier results.

    Args:
        raw: Text such as ``"1,2,4,8"``.

    Returns:
        The distinct positive factors, in the order they first appeared.

    Raises:
        SystemExit: If an entry is not an integer, is not positive, or if no
            values remain after parsing.
    """
    vals: List[int] = []
    seen = set()
    for part in raw.split(","):
        part = part.strip()
        if not part:
            continue
        try:
            v = int(part)
        except ValueError:
            # Suppress the ValueError context; the CLI message is enough.
            raise SystemExit(f"Invalid overquery value: {part}") from None
        if v <= 0:
            raise SystemExit("overquery values must be positive")
        if v in seen:
            continue
        seen.add(v)
        vals.append(v)
    if not vals:
        raise SystemExit("No overquery values provided")
    return vals
| 55 | + |
| 56 | + |
def load_existing_config(db_path: Path) -> Dict:
    """Return the ``config`` mapping stored in ``<db_path>/results.json``.

    This is a best-effort lookup used to seed defaults for a sweep. Any
    problem with the file results in an empty dict rather than an error.

    Args:
        db_path: Directory expected to contain a ``results.json`` file.

    Returns:
        The value of the top-level ``"config"`` key when it is a dict;
        otherwise ``{}`` (file missing, unreadable, not valid JSON, top level
        not an object, or ``"config"`` absent / null / not an object).
    """
    res_json = db_path / "results.json"
    try:
        payload = json.loads(res_json.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing or unreadable file. ValueError: malformed JSON or
        # undecodable bytes (JSONDecodeError and UnicodeDecodeError are both
        # ValueError subclasses).
        return {}
    if not isinstance(payload, dict):
        return {}
    cfg = payload.get("config")
    # Guard against `"config": null` or a non-object value so callers can
    # always use dict methods on the result.
    return cfg if isinstance(cfg, dict) else {}
| 66 | + |
| 67 | + |
def record(
    phases: Dict[str, dict],
    name: str,
    result,
    dur: float,
    rss_start: float,
    rss_end: float,
) -> None:
    """Store timing and memory figures for one phase in ``phases[name]``.

    The entry always holds ``time_sec``, ``rss_before_mb``, ``rss_after_mb``
    and ``rss_delta_mb`` (after minus before). When ``result`` is a dict its
    items are merged on top, so a key in ``result`` overrides a base key of
    the same name. Any non-dict ``result`` is ignored. An existing entry for
    ``name`` is replaced.

    Args:
        phases: Mapping that is mutated in place.
        name: Phase name used as the key.
        result: Phase return value; only dicts contribute extra fields.
        dur: Elapsed time in seconds.
        rss_start: RSS before the phase, in MB.
        rss_end: RSS after the phase, in MB.
    """
    entry = dict(
        time_sec=dur,
        rss_before_mb=rss_start,
        rss_after_mb=rss_end,
        rss_delta_mb=rss_end - rss_start,
    )
    extras = result if isinstance(result, dict) else {}
    entry.update(extras)
    phases[name] = entry
| 84 | + |
| 85 | + |
def _fmt_optional(val: float | None, spec: str, missing: str = "n/a") -> str:
    """Format ``val`` with ``spec``; return ``missing`` when ``val`` is None."""
    return missing if val is None else format(val, spec)


def _render_results_md(results: Dict, k: int, peak_rss: float | None) -> str:
    """Build the human-readable ``results.md`` text for one sweep run."""
    config = results["config"]
    recall = results["recall"]["search"]
    latency = results["latency_ms"]["search"]
    phases = results["phases"]

    if recall["mean"] is not None:
        recall_line = f"- search: {recall['mean']:.4f} (n={recall['n']})"
    else:
        recall_line = "- search: n/a"

    if latency["mean"] is not None:
        # p95 may be missing even when the mean is present; never crash on it.
        latency_line = (
            f"- search mean: {latency['mean']:.2f} | "
            f"p95: {_fmt_optional(latency['p95'], '.2f')}"
        )
    else:
        latency_line = "- search: n/a"

    md_lines = [
        f"# ArcadeDB overquery sweep ({results['dataset']['label']})",
        "",
        "## Config",
        f"- overquery_factor: {config['overquery_factor']}",
        f"- quantization: {config['quantization']}",
        f"- k: {k}",
        f"- db_path: {results['db_path']}",
        "",
        "## Recall",
        recall_line,
        "",
        "## Latency (ms)",
        latency_line,
        "",
        "## Phases (time sec / RSS MB)",
    ]

    for name in ("load_queries", "open_db", "warmup", "search", "close_db_final"):
        p = phases.get(name)
        if p is None:
            continue
        line = (
            f"- {name}: time={p['time_sec']:.3f}s, rss_before={p['rss_before_mb']:.1f} MB, "
            f"rss_after={p['rss_after_mb']:.1f} MB, delta={p['rss_delta_mb']:.1f} MB"
        )
        if p.get("recall_mean") is not None:
            line += f", recall@{k}={p['recall_mean']:.4f}"
        if p.get("latency_ms_mean") is not None:
            line += f", latency_ms={p['latency_ms_mean']:.2f}"
        md_lines.append(line)

    md_lines.extend(
        [
            "",
            f"- db_size_mb: {_fmt_optional(results['db_size_mb'], '.1f', 'nan')}",
            f"- peak_rss_mb: {_fmt_optional(peak_rss, '.1f', 'nan')}",
        ]
    )
    return "\n".join(md_lines)


def run_single(
    db_path: Path,
    dataset_dir: Path,
    overquery: int,
    k: int,
    quantization: str,
    output_root: Path,
    tag: str | None,
    base_config: Dict,
) -> Path:
    """Run one search-only pass for a single overquery factor and save results.

    Phases (each timed and RSS-sampled via ``timed_section``): materialize the
    query vectors, open the existing database, warm up, run the search pass,
    and close the database. The database is closed even when warmup or search
    raises. Results are written as ``results.json`` and ``results.md`` in a
    run directory under ``output_root``.

    Args:
        db_path: Existing ArcadeDB database directory.
        dataset_dir: Dataset directory passed to ``resolve_dataset``.
        overquery: Overquery factor forwarded to warmup and search.
        k: Top-K forwarded to warmup and search.
        quantization: Quantization name forwarded to warmup and search.
        output_root: Parent directory for the per-run output directory.
        tag: Optional suffix appended to the run directory name.
        base_config: Previously stored config; its non-None entries are copied
            into the new config and then overridden by this run's settings.

    Returns:
        The run directory containing ``results.json`` and ``results.md``.

    Raises:
        SystemExit: If no query ID is both within the dataset's row count and
            present in the ground truth.
    """
    import sys  # local import: only needed to report a failed close

    sources, gt_path, dim, label = resolve_dataset(dataset_dir)
    total_rows = sum(s["count"] for s in sources)
    gt_full = load_ground_truth(gt_path)

    # Keep only query IDs that address an existing row and have ground truth.
    qids = load_queries(gt_path, limit=1000)
    qids = [qid for qid in qids if qid < total_rows][:1000]
    qids = [qid for qid in qids if qid in gt_full][:1000]
    if not qids:
        raise SystemExit("No valid query IDs with ground truth found")

    phases: Dict[str, dict] = {}

    (queries, dur, r0, r1) = timed_section(
        "load_queries", lambda: materialize_queries(sources, qids, dim=dim)
    )
    record(phases, "load_queries", {"queries": len(queries)}, dur, r0, r1)

    (db, dur, r0, r1) = timed_section(
        "open_db", lambda: arcadedb.open_database(str(db_path))
    )
    record(phases, "open_db", {}, dur, r0, r1)

    try:
        index = db.schema.get_vector_index("VectorData", "vector")

        (warm_info, dur, r0, r1) = timed_section(
            "warmup",
            lambda: warmup(index, queries, overquery, k, quantization),
        )
        record(phases, "warmup", warm_info, dur, r0, r1)

        (search_stats, dur, r0, r1) = timed_section(
            "search",
            lambda: search_index(
                index,
                queries,
                qids,
                gt_full,
                k=k,
                overquery_factor=overquery,
                quantization=quantization,
            ),
        )
        record(phases, "search", search_stats, dur, r0, r1)
    finally:
        # Always release the DB handle, including when a phase above failed.
        # Closing stays best-effort, but a failure is reported rather than
        # silently dropped.
        try:
            (_, dur, r0, r1) = timed_section("close_db_final", lambda: db.close())
            record(phases, "close_db_final", {}, dur, r0, r1)
        except Exception as exc:
            print(f"Warning: failed to close database cleanly: {exc}", file=sys.stderr)

    rss_after_vals = [
        p["rss_after_mb"]
        for p in phases.values()
        if p.get("rss_after_mb") is not None
    ]
    peak_rss = max(rss_after_vals) if rss_after_vals else None

    search_phase = phases.get("search", {})

    # The "search_after_reopen" entries are always null here; they appear to
    # exist so the layout matches what the summarizer reads — TODO confirm.
    recall_stats = {
        "search": {
            "mean": search_phase.get("recall_mean"),
            "n": search_phase.get("recall_count"),
        },
        "search_after_reopen": {"mean": None, "n": None},
    }

    latency_ms = {
        "search": {
            "mean": search_phase.get("latency_ms_mean"),
            "p95": search_phase.get("latency_ms_p95"),
        },
        "search_after_reopen": {"mean": None, "p95": None},
    }

    dataset_info = {
        "label": label or "dataset",
        "dim": dim,
        "shards": len(sources),
        "rows": total_rows,
    }

    # Settings of this run take precedence over values inherited from the
    # stored config.
    config_info = {
        **{key: val for key, val in base_config.items() if val is not None},
        "overquery_factor": overquery,
        "quantization": quantization,
        "queries": len(qids),
        "k": k,
    }

    results = {
        "dataset": dataset_info,
        "config": config_info,
        "phases": phases,
        "recall": recall_stats,
        "latency_ms": latency_ms,
        "db_path": str(db_path),
        "db_size_mb": dir_size_mb(db_path),
    }

    run_dir_name_parts = [
        f"dataset={dataset_dir.name}",
        f"label={label or 'dataset'}",
        f"oq={overquery}",
        f"reuse={db_path.name}",
    ]
    if tag:
        run_dir_name_parts.append(f"tag={tag}")
    run_dir = output_root / "_".join(run_dir_name_parts)
    run_dir.mkdir(parents=True, exist_ok=True)

    results_json = run_dir / "results.json"
    results_json.write_text(json.dumps(results, indent=2), encoding="utf-8")

    results_md = run_dir / "results.md"
    results_md.write_text(_render_results_md(results, k, peak_rss), encoding="utf-8")
    print(f"Wrote {results_json}")
    print(f"Wrote {results_md}")

    return run_dir
| 265 | + |
| 266 | + |
def main() -> None:
    """Parse CLI arguments and run one search-only sweep per overquery factor.

    Quantization is taken from ``--quantization`` when given, otherwise from
    the ``config`` stored next to the database, otherwise ``"NONE"``. The
    resulting run directories are printed at the end.

    Raises:
        SystemExit: On invalid arguments (including a non-positive ``--k``),
            a missing database path, or any fatal error raised by a run.
    """
    ap = argparse.ArgumentParser(
        description="Run overquery sweeps against an existing ArcadeDB MSMARCO DB"
    )
    ap.add_argument("--db-path", required=True, help="Path to existing ArcadeDB DB")
    ap.add_argument("--dataset-dir", required=True, help="Path to MSMARCO dataset dir")
    ap.add_argument(
        "--overquery-factors",
        required=True,
        help="Comma-separated overquery values (e.g., 1,2,4,8,16)",
    )
    ap.add_argument(
        "--output-root",
        default="arcadedb_runs",
        help="Where to place per-factor results (default: arcadedb_runs)",
    )
    ap.add_argument("--k", type=int, default=50, help="Top-K for recall/latency")
    ap.add_argument(
        "--quantization",
        choices=["NONE", "INT8", "BINARY", "PRODUCT"],
        help="Override quantization (default: from existing results.json or NONE)",
    )
    ap.add_argument("--tag", help="Optional tag appended to output directory name")

    args = ap.parse_args()

    # Reject a meaningless top-K up front instead of failing inside a run.
    if args.k <= 0:
        ap.error("--k must be a positive integer")

    overqueries = parse_overqueries(args.overquery_factors)
    db_path = Path(args.db_path)
    dataset_dir = Path(args.dataset_dir)
    output_root = Path(args.output_root)

    if not db_path.exists():
        raise SystemExit(f"DB path not found: {db_path}")

    base_config = load_existing_config(db_path)
    if not isinstance(base_config, dict):
        # A null or non-object stored config must not break the sweep.
        base_config = {}
    # str() guards against a non-string quantization value in the stored config.
    quant = str(
        args.quantization or base_config.get("quantization") or "NONE"
    ).upper()

    created: List[Path] = [
        run_single(
            db_path=db_path,
            dataset_dir=dataset_dir,
            overquery=oq,
            k=args.k,
            quantization=quant,
            output_root=output_root,
            tag=args.tag,
            base_config=base_config,
        )
        for oq in overqueries
    ]

    print("\nCompleted runs:")
    for p in created:
        print(f"- {p}")


if __name__ == "__main__":
    main()
0 commit comments