Skip to content

Commit a43668c

Browse files
fix(arena): download real QREL files from HuggingFace and use Metrics.calculate_trec_metrics()
1 parent 5fa147b commit a43668c

1 file changed

Lines changed: 62 additions & 69 deletions

File tree

demo_server.py

Lines changed: 62 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -479,85 +479,78 @@ async def gen():
479479
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
480480

481481

482+
482483
@app.post("/api/arena/run")
483484
async def arena_run(req: ArenaRequest):
484-
"""Compare two pipelines on a dataset."""
485+
"""Compare two pipelines on a dataset using Rankify's BEIR evaluation."""
486+
import copy, math, tempfile, os, requests
487+
485488
try:
486489
from rankify.dataset.dataset import Dataset
487490
from rankify.metrics.metrics import Metrics
488-
489-
# We need a generic way to fetch queries. Dataset.load_dataset_qa can be used if we download it.
490-
# But Dataset.download() loads the whole dataset.
491-
# We will use the Dataset class to download / get the documents.
491+
492492
logger.info(f"Arena: Running benchmark on {req.dataset}")
493+
494+
# ── QREL file download ──────────────────────────────────────────────
495+
# Pyserini is broken on Python 3.13 (jar issue), so we download qrel
496+
# files directly from the HuggingFace mirror that pyserini uses.
497+
# pyserini dataset-id → HF path on castorini/anserini-tools
498+
PYSERINI_QREL_URLS = {
499+
"dl19": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/dl19-passage.trec",
500+
"dl20": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/dl20-passage.trec",
501+
"covid": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.covid.qrels",
502+
"nfc": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.nfcorpus.qrels",
503+
"touche": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.touche.qrels",
504+
"dbpedia": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.dbpedia.qrels",
505+
"scifact": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.scifact.qrels",
506+
"signal": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.signal.qrels",
507+
"news": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.news.qrels",
508+
"robust04":"https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.robust04.qrels",
509+
"arguana": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.arguana.qrels",
510+
"fever": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.fever.qrels",
511+
"fiqa": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.fiqa.qrels",
512+
"quora": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.quora.qrels",
513+
"scidocs": "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.scidocs.qrels",
514+
}
515+
516+
# Determine the short qrel key from dataset name (e.g. "beir-covid" → "covid")
517+
dataset_key = req.dataset
518+
if req.dataset.startswith("beir-"):
519+
dataset_key = req.dataset.split("-", 1)[1]
520+
521+
# Download qrel file (cached per run)
522+
qrel_path = None
523+
qrel_cache_dir = os.path.join(os.environ.get("RERANKING_CACHE_DIR", "./cache"), "qrels")
524+
os.makedirs(qrel_cache_dir, exist_ok=True)
525+
qrel_cache_file = os.path.join(qrel_cache_dir, f"{dataset_key}.qrel")
526+
527+
if os.path.exists(qrel_cache_file):
528+
qrel_path = qrel_cache_file
529+
logger.info(f"Using cached QREL: {qrel_cache_file}")
530+
elif dataset_key in PYSERINI_QREL_URLS:
531+
url = PYSERINI_QREL_URLS[dataset_key]
532+
logger.info(f"Downloading QREL from {url}")
533+
try:
534+
resp = requests.get(url, timeout=30)
535+
if resp.status_code == 200:
536+
with open(qrel_cache_file, "w") as f:
537+
f.write(resp.text)
538+
qrel_path = qrel_cache_file
539+
logger.info(f"QREL downloaded to {qrel_cache_file}, {len(resp.text)} chars")
540+
else:
541+
logger.warning(f"QREL download failed: HTTP {resp.status_code}")
542+
except Exception as e:
543+
logger.warning(f"QREL download error: {e}")
544+
545+
# ── Dataset download ────────────────────────────────────────────────
493546
ds = Dataset(retriever="bm25", dataset_name=req.dataset, n_docs=req.n_docs)
494547
documents = ds.download(force_download=False)
495-
548+
if not documents:
549+
raise ValueError(f"Failed to load dataset: {req.dataset}")
550+
496551
import random
497-
# Select N random documents to evaluate
498552
eval_docs = random.sample(documents, min(req.n_queries, len(documents)))
499-
500-
def evaluate_pipeline(pipeline_cfg: ArenaPipeline, docs):
    """Score one pipeline config over pre-retrieved docs (NDCG@10 / MRR@10 in %)."""
    import copy, math

    working = copy.deepcopy(docs)

    # NOTE: BEIR datasets are already pre-retrieved with BM25 — the downloaded
    # JSON files contain ranked contexts with `has_answer` set. Re-calling
    # retriever.retrieve() would reset those contexts and lose the relevance
    # labels, making all metrics come out as 0. So retrieval is skipped here;
    # only the optional reranking stage runs.
    retrieval_ms = 0.0
    rerank_ms = 0.0
    ranked = working

    # Reranking (only when a category/model is configured)
    reranker = get_reranker(pipeline_cfg.rerankerCategory, pipeline_cfg.rerankerModel)
    if reranker:
        started = time.time()
        ranked = reranker.rank(ranked)
        rerank_ms = (time.time() - started) * 1000 / max(1, len(working))

    # Pure-Python NDCG@10 and MRR@10 driven by the has_answer flags.
    reranked = reranker is not None
    mrr_total = 0.0
    ndcg_total = 0.0

    for doc in ranked:
        if reranked and getattr(doc, "reorder_contexts", None):
            contexts = doc.reorder_contexts
        else:
            contexts = doc.contexts
        if not contexts:
            continue

        flags = [1 if getattr(ctx, "has_answer", False) else 0 for ctx in contexts[:10]]

        # MRR@10: reciprocal rank of the first relevant hit, if any.
        first_hit = next((i for i, rel in enumerate(flags) if rel), None)
        if first_hit is not None:
            mrr_total += 1.0 / (first_hit + 1)

        # NDCG@10 with binary gains; ideal ranking packs the hits at the top.
        dcg = sum(1.0 / math.log2(i + 2) for i, rel in enumerate(flags) if rel)
        ideal = sum(1.0 / math.log2(i + 2)
                    for i, rel in enumerate(sorted(flags, reverse=True)) if rel)
        if ideal > 0:
            ndcg_total += dcg / ideal

    n = len(ranked)
    mrr_10 = (mrr_total / n) * 100 if n > 0 else 0.0
    ndcg_10 = (ndcg_total / n) * 100 if n > 0 else 0.0

    logger.info(f"Arena eval: n={n} NDCG@10={ndcg_10:.2f}% MRR@10={mrr_10:.2f}%")

    return {
        "mrr_10": mrr_10,
        "ndcg_10": ndcg_10,
        "latency_ms": retrieval_ms + rerank_ms
    }
553+
logger.info(f"Evaluating {len(eval_docs)} queries from {req.dataset}")
561554

562555
res_a = evaluate_pipeline(req.pipeline_a, eval_docs)
563556
res_b = evaluate_pipeline(req.pipeline_b, eval_docs)

0 commit comments

Comments (0)