@@ -479,85 +479,78 @@ async def gen():
479479 headers = {"Cache-Control" : "no-cache" , "X-Accel-Buffering" : "no" })
480480
481481
482+
482483@app .post ("/api/arena/run" )
483484async def arena_run (req : ArenaRequest ):
484- """Compare two pipelines on a dataset."""
485+ """Compare two pipelines on a dataset using Rankify's BEIR evaluation."""
486+ import copy , math , tempfile , os , requests
487+
485488 try :
486489 from rankify .dataset .dataset import Dataset
487490 from rankify .metrics .metrics import Metrics
488-
489- # We need a generic way to fetch queries. Dataset.load_dataset_qa can be used if we download it.
490- # But Dataset.download() loads the whole dataset.
491- # We will use the Dataset class to download / get the documents.
491+
492492 logger .info (f"Arena: Running benchmark on { req .dataset } " )
493+
494+ # ── QREL file download ──────────────────────────────────────────────
495+ # Pyserini is broken on Python 3.13 (jar issue), so we download qrel
496+ # files directly from the HuggingFace mirror that pyserini uses.
497+ # pyserini dataset-id → HF path on castorini/beir-qrels
498+ PYSERINI_QREL_URLS = {
499+ "dl19" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/dl19-passage.trec" ,
500+ "dl20" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/dl20-passage.trec" ,
501+ "covid" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.covid.qrels" ,
502+ "nfc" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.nfcorpus.qrels" ,
503+ "touche" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.touche.qrels" ,
504+ "dbpedia" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.dbpedia.qrels" ,
505+ "scifact" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.scifact.qrels" ,
506+ "signal" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.signal.qrels" ,
507+ "news" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.news.qrels" ,
508+ "robust04" :"https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.robust04.qrels" ,
509+ "arguana" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.arguana.qrels" ,
510+ "fever" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.fever.qrels" ,
511+ "fiqa" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.fiqa.qrels" ,
512+ "quora" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.quora.qrels" ,
513+ "scidocs" : "https://huggingface.co/datasets/castorini/beir-qrels/resolve/main/test.scidocs.qrels" ,
514+ }
515+
516+ # Determine the short qrel key from dataset name (e.g. "beir-covid" → "covid")
517+ dataset_key = req .dataset
518+ if req .dataset .startswith ("beir-" ):
519+ dataset_key = req .dataset .split ("-" , 1 )[1 ]
520+
521+ # Download qrel file (cached per run)
522+ qrel_path = None
523+ qrel_cache_dir = os .path .join (os .environ .get ("RERANKING_CACHE_DIR" , "./cache" ), "qrels" )
524+ os .makedirs (qrel_cache_dir , exist_ok = True )
525+ qrel_cache_file = os .path .join (qrel_cache_dir , f"{ dataset_key } .qrel" )
526+
527+ if os .path .exists (qrel_cache_file ):
528+ qrel_path = qrel_cache_file
529+ logger .info (f"Using cached QREL: { qrel_cache_file } " )
530+ elif dataset_key in PYSERINI_QREL_URLS :
531+ url = PYSERINI_QREL_URLS [dataset_key ]
532+ logger .info (f"Downloading QREL from { url } " )
533+ try :
534+ resp = requests .get (url , timeout = 30 )
535+ if resp .status_code == 200 :
536+ with open (qrel_cache_file , "w" ) as f :
537+ f .write (resp .text )
538+ qrel_path = qrel_cache_file
539+ logger .info (f"QREL downloaded to { qrel_cache_file } , { len (resp .text )} chars" )
540+ else :
541+ logger .warning (f"QREL download failed: HTTP { resp .status_code } " )
542+ except Exception as e :
543+ logger .warning (f"QREL download error: { e } " )
544+
545+ # ── Dataset download ────────────────────────────────────────────────
493546 ds = Dataset (retriever = "bm25" , dataset_name = req .dataset , n_docs = req .n_docs )
494547 documents = ds .download (force_download = False )
495-
548+ if not documents :
549+ raise ValueError (f"Failed to load dataset: { req .dataset } " )
550+
496551 import random
497- # Select N random documents to evaluate
498552 eval_docs = random .sample (documents , min (req .n_queries , len (documents )))
499-
500- def evaluate_pipeline (pipeline_cfg : ArenaPipeline , docs ):
501- import copy , math
502- docs_copy = copy .deepcopy (docs )
503-
504- # NOTE: BEIR datasets are already pre-retrieved with BM25 – the downloaded
505- # JSON files contain ranked contexts with `has_answer` set.
506- # Re-calling retriever.retrieve() would reset those contexts and lose the
507- # relevance labels, making all metrics come out as 0.
508- # So we skip re-retrieval; we only rerank if a category is configured.
509- ret_latency = 0.0
510- rr_latency = 0.0
511- ret_results = docs_copy
512-
513- # Reranking
514- reranker = get_reranker (pipeline_cfg .rerankerCategory , pipeline_cfg .rerankerModel )
515- if reranker :
516- t1 = time .time ()
517- ret_results = reranker .rank (ret_results )
518- rr_latency = (time .time () - t1 ) * 1000 / max (1 , len (docs_copy ))
519-
520- # Evaluate: pure-Python NDCG@10 and MRR@10 using has_answer flags
521- use_rr = reranker is not None
522- mrr_sum = 0.0
523- ndcg_sum = 0.0
524-
525- for doc in ret_results :
526- contexts = doc .reorder_contexts if (use_rr and getattr (doc , "reorder_contexts" , None )) else doc .contexts
527- if not contexts :
528- continue
529-
530- # MRR@10
531- for i , ctx in enumerate (contexts [:10 ]):
532- if getattr (ctx , "has_answer" , False ):
533- mrr_sum += 1.0 / (i + 1 )
534- break
535-
536- # NDCG@10 (binary relevance)
537- dcg = 0.0
538- rels = []
539- for i , ctx in enumerate (contexts [:10 ]):
540- rel = 1 if getattr (ctx , "has_answer" , False ) else 0
541- rels .append (rel )
542- if rel :
543- dcg += 1.0 / math .log2 (i + 2 )
544-
545- rels_sorted = sorted (rels , reverse = True )
546- idcg = sum (r / math .log2 (i + 2 ) for i , r in enumerate (rels_sorted ) if r )
547- if idcg > 0 :
548- ndcg_sum += dcg / idcg
549-
550- n = len (ret_results )
551- mrr_10 = (mrr_sum / n ) * 100 if n > 0 else 0.0
552- ndcg_10 = (ndcg_sum / n ) * 100 if n > 0 else 0.0
553-
554- logger .info (f"Arena eval: n={ n } NDCG@10={ ndcg_10 :.2f} % MRR@10={ mrr_10 :.2f} %" )
555-
556- return {
557- "mrr_10" : mrr_10 ,
558- "ndcg_10" : ndcg_10 ,
559- "latency_ms" : ret_latency + rr_latency
560- }
553+ logger .info (f"Evaluating { len (eval_docs )} queries from { req .dataset } " )
561554
562555 res_a = evaluate_pipeline (req .pipeline_a , eval_docs )
563556 res_b = evaluate_pipeline (req .pipeline_b , eval_docs )