
Commit 967ea64

Merge pull request #108 from DataScienceUIBK/demo
fix: correct Diver/Reasonir model registry, remove arena BM25 selecto…
2 parents e26f4cd + def1632

5 files changed (+246, -71 lines)

demo-web/src/app/agent/page.tsx (3 additions, 3 deletions)

@@ -126,7 +126,7 @@ export default function AgentPage() {
   strong: ({ node, ...props }) => <strong className="font-bold text-slate-900" {...props} />,
   code: ({ node, inline, ...props }: any) =>
     inline ? (
-      <code className="bg-slate-100 text-pink-600 px-1.5 py-0.5 rounded-md text-[13px] font-mono" {...props} />
+      <code className="bg-slate-100 text-violet-700 px-1.5 py-0.5 rounded-md text-[13px] font-mono font-semibold" {...props} />
     ) : (
       <div className="my-4 rounded-xl overflow-hidden border border-slate-200 bg-[#0d1117] shadow-sm">
         <div className="flex items-center px-4 py-2 bg-slate-800 border-b border-slate-700">
@@ -136,8 +136,8 @@ export default function AgentPage() {
           <div className="w-3 h-3 rounded-full bg-green-500/80"></div>
         </div>
       </div>
-      <pre className="p-4 overflow-x-auto">
-        <code className="text-[13px] font-mono text-slate-50 leading-relaxed" {...props} />
+      <pre className="p-4 overflow-x-auto bg-[#0d1117]">
+        <code className="text-[13px] font-mono text-white leading-relaxed whitespace-pre-wrap" {...props} />
       </pre>
     </div>
   ),

demo-web/src/app/arena/page.tsx (2 additions, 4 deletions)

@@ -157,11 +157,10 @@ export default function ArenaPage() {
   Pipeline A (Baseline)
 </div>
 <div className="flex flex-col gap-4">
-  <Sel value={pipeA.retriever} onChange={() => { }} opts={[{ value: "bm25", label: "BM25 (Sparse Base)" }]} label="Retriever (Fixed by BEIR)" icon={Search} />
   <Sel
     value={pipeA.method}
     onChange={v => setPipeA(p => ({ ...p, method: v, model: v === "none" ? "none" : (RERANKERS_MAP[v as keyof typeof RERANKERS_MAP]?.[0] || "") }))}
-    opts={METHODS.map(m => ({ value: m, label: m === "none" ? "None (Base BM25 Only)" : m }))}
+    opts={METHODS.map(m => ({ value: m, label: m === "none" ? "None (BM25 baseline only)" : m }))}
     label="Reranking Method"
     icon={ListTree}
   />
@@ -184,11 +183,10 @@ export default function ArenaPage() {
   Pipeline B (Challenger)
 </div>
 <div className="flex flex-col gap-4">
-  <Sel value={pipeB.retriever} onChange={() => { }} opts={[{ value: "bm25", label: "BM25 (Sparse Base)" }]} label="Retriever (Fixed by BEIR)" icon={Search} />
   <Sel
     value={pipeB.method}
     onChange={v => setPipeB(p => ({ ...p, method: v, model: v === "none" ? "none" : (RERANKERS_MAP[v as keyof typeof RERANKERS_MAP]?.[0] || "") }))}
-    opts={METHODS.map(m => ({ value: m, label: m === "none" ? "None (Base BM25 Only)" : m }))}
+    opts={METHODS.map(m => ({ value: m, label: m === "none" ? "None (BM25 baseline only)" : m }))}
     label="Reranking Method"
     icon={ListTree}
   />

demo_server.py (41 additions, 62 deletions)

@@ -498,81 +498,60 @@ async def arena_run(req: ArenaRequest):
     eval_docs = random.sample(documents, min(req.n_queries, len(documents)))

     def evaluate_pipeline(pipeline_cfg: ArenaPipeline, docs):
-        import copy
+        import copy, math
         docs_copy = copy.deepcopy(docs)

-        # Retrieval
-        idx_type = "msmarco" if req.dataset == "msmarco" else "wiki"
-        retriever = get_retriever(pipeline_cfg.retriever, n_docs=req.n_docs, index_type=idx_type)
-        t0 = time.time()
-        ret_results = retriever.retrieve(docs_copy)
-        ret_latency = (time.time() - t0) * 1000 / len(docs_copy)
+        # NOTE: BEIR datasets are already pre-retrieved with BM25 – the downloaded
+        # JSON files contain ranked contexts with `has_answer` set.
+        # Re-calling retriever.retrieve() would reset those contexts and lose the
+        # relevance labels, making all metrics come out as 0.
+        # So we skip re-retrieval; we only rerank if a category is configured.
+        ret_latency = 0.0
+        rr_latency = 0.0
+        ret_results = docs_copy

         # Reranking
-        rr_latency = 0
         reranker = get_reranker(pipeline_cfg.rerankerCategory, pipeline_cfg.rerankerModel)
         if reranker:
             t1 = time.time()
             ret_results = reranker.rank(ret_results)
-            rr_latency = (time.time() - t1) * 1000 / len(docs_copy)
+            rr_latency = (time.time() - t1) * 1000 / max(1, len(docs_copy))

-        # Evaluation - use true TREC evaluation as per BEIR standards
-        metrics = Metrics(ret_results)
+        # Evaluate: pure-Python NDCG@10 and MRR@10 using has_answer flags
         use_rr = reranker is not None
+        mrr_sum = 0.0
+        ndcg_sum = 0.0

-        # Formulate the correct qrel name for rankify
-        qrel_name = req.dataset
-        if req.dataset.startswith("beir-"):
-            qrel_name = req.dataset.split("-")[1]
-        elif req.dataset in ["nq-dev", "msmarco", "triviaqa"]:
-            qrel_name = req.dataset
+        for doc in ret_results:
+            contexts = doc.reorder_contexts if (use_rr and getattr(doc, "reorder_contexts", None)) else doc.contexts
+            if not contexts:
+                continue

-        try:
-            trec_metrics = metrics.calculate_trec_metrics(
-                ndcg_cuts=[10],
-                map_cuts=[10],
-                mrr_cuts=[10],
-                qrel=qrel_name,
-                use_reordered=use_rr
-            )
-            ndcg_10 = trec_metrics.get("ndcg@10", 0) * 100
-            mrr_10 = trec_metrics.get("mrr@10", 0) * 100
-        except Exception as e:
-            logger.error(f"TREC Eval Error: {e}")
-            ndcg_10, mrr_10 = 0, 0
-
-        # FALLBACK: If pyserini fails or returns 0.0 (happens on Python 3.13),
-        # we use a manual calculation based on doc.has_answer.
-        if ndcg_10 == 0 and mrr_10 == 0:
-            logger.warning(f"TREC Eval returned 0.0 for {qrel_name}, using manual fallback.")
-            import math
-            mrr_sum = 0
-            ndcg_sum = 0
-            for doc in ret_results:
-                contexts = doc.reorder_contexts if (use_rr and doc.reorder_contexts) else doc.contexts
-                # MRR
-                found_at = -1
-                for i, ctx in enumerate(contexts[:10]):
-                    if getattr(ctx, "has_answer", False):
-                        found_at = i + 1
-                        break
-                if found_at > 0: mrr_sum += 1.0 / found_at
-
-                # NDCG (Binary)
-                dcg = 0
-                hits_rels = []
-                for i, ctx in enumerate(contexts[:10]):
-                    rel = 1 if getattr(ctx, "has_answer", False) else 0
-                    hits_rels.append(rel)
-                    if rel: dcg += 1.0 / math.log2(i + 2)
-
-                hits_rels.sort(reverse=True)
-                idcg = sum(1.0 / math.log2(i + 2) for i, rel in enumerate(hits_rels) if rel)
-                if idcg > 0: ndcg_sum += (dcg / idcg)
+            # MRR@10
+            for i, ctx in enumerate(contexts[:10]):
+                if getattr(ctx, "has_answer", False):
+                    mrr_sum += 1.0 / (i + 1)
+                    break
+
+            # NDCG@10 (binary relevance)
+            dcg = 0.0
+            rels = []
+            for i, ctx in enumerate(contexts[:10]):
+                rel = 1 if getattr(ctx, "has_answer", False) else 0
+                rels.append(rel)
+                if rel:
+                    dcg += 1.0 / math.log2(i + 2)

-        n = len(ret_results)
-        mrr_10 = (mrr_sum / n) * 100 if n > 0 else 0
-        ndcg_10 = (ndcg_sum / n) * 100 if n > 0 else 0
+            rels_sorted = sorted(rels, reverse=True)
+            idcg = sum(r / math.log2(i + 2) for i, r in enumerate(rels_sorted) if r)
+            if idcg > 0:
+                ndcg_sum += dcg / idcg
+
+        n = len(ret_results)
+        mrr_10 = (mrr_sum / n) * 100 if n > 0 else 0.0
+        ndcg_10 = (ndcg_sum / n) * 100 if n > 0 else 0.0
+
+        logger.info(f"Arena eval: n={n} NDCG@10={ndcg_10:.2f}% MRR@10={mrr_10:.2f}%")

         return {
             "mrr_10": mrr_10,

rankify/agent/agent.py (29 additions, 2 deletions)

@@ -54,12 +54,39 @@ class RankifyAgent:
 Rankify is a comprehensive Python toolkit for Retrieval, Re-Ranking, and Retrieval-Augmented Generation (RAG).

 Your job is to help users select the best models for their use case. You have access to:
-- 10 retrieval methods (BM25, DPR, ANCE, BGE, ColBERT, Contriever, HyDE, Online)
+- **Sparse Retrievers**: BM25 (fast, no GPU, exact match)
+- **Dense Retrievers**: DPR, ANCE, BGE, ColBERT, Contriever
+- **Diver Dense Retrievers** (method="diver-dense"): Many bi-encoder and LLM-based variants selectable via model_id
+- **Reasoning-Augmented Retrievers** (SOTA on BRIGHT benchmark): ReasonIR-8B, ReasonEmbed, BGE-Reasoner-Embed
+- **Online Retriever**: Web search via APIs (real-time data)
+- **HyDE**: Hypothetical Document Embedding for complex queries
 - 23 reranking methods (MonoT5, FlashRank, RankGPT, InRanker, ColBERT, API rerankers, etc.)
 - 7 RAG methods (Basic RAG, Chain-of-Thought, Self-Consistency, ReAct, FiD, etc.)

+**Diver Dense Retriever Guide (method="diver-dense"):**
+Valid model_ids (must have corpus_path):
+- `bge` → BAAI/bge-large-en-v1.5
+- `sbert` → sentence-transformers/all-mpnet-base-v2
+- `nomic` → nomic-ai/nomic-embed-text-v1
+- `diver` → AQ-MedAI/Diver-Retriever-4B (flagship diverse evidence model)
+- `inst-l` → hkunlp/instructor-large
+- `inst-xl` → hkunlp/instructor-xl
+- `e5` → intfloat/e5-mistral-7b-instruct (LLM-based)
+- `sf` → Salesforce/SFR-Embedding-Mistral (LLM-based)
+- `rader` → Raderspace/RaDeR_Qwen_25_7B (reasoning-augmented)
+- `grit` → GritLM/GritLM-7B (generative representation)
+- `m2` → togethercomputer/m2-bert-80M-32k-retrieval (long context, 32k)
+- `contriever` → facebook/contriever-msmarco
+
+Example: `Retriever(method="diver-dense", model_id="diver", corpus_path="data/corpus.jsonl", n_docs=10)`
+
+**Reasoning-Augmented Retrievers:**
+- `Retriever(method="reasonir", corpus_path=...)` → reasonir/ReasonIR-8B (SOTA BRIGHT benchmark)
+- `Retriever(method="reason-embed", model_id="qwen3-8b"|"qwen3-4b"|"llama-8b", corpus_path=...)`
+- `Retriever(method="bge-reasoner-embed", corpus_path=...)` → BAAI/bge-reasoner-embed-qwen3-8b
+
 When helping users, consider:
-1. Their task type (QA, search, summarization, conversational)
+1. Their task type (QA, search, summarization, conversational, reasoning-intensive)
 2. Hardware constraints (GPU availability, memory)
 3. Latency requirements
 4. Whether they can use APIs or need local models
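
The prompt now doubles as an API cheat-sheet. Below is a short sketch of the calls it teaches, with constructor arguments taken verbatim from the prompt; the import path, corpus file, and variable names are assumptions for illustration.

# Sketch of the usage the updated system prompt documents. The constructor
# arguments come from the prompt above; the import path and corpus file are
# assumed placeholders.
from rankify.retrievers.retriever import Retriever  # assumed import path

# Flagship Diver model, chosen via model_id exactly as the prompt's example shows.
diver = Retriever(
    method="diver-dense",
    model_id="diver",                 # -> AQ-MedAI/Diver-Retriever-4B
    corpus_path="data/corpus.jsonl",  # required for all diver-dense models
    n_docs=10,
)

# Reasoning-augmented alternatives for BRIGHT-style, reasoning-intensive queries.
reasonir = Retriever(method="reasonir", corpus_path="data/corpus.jsonl")
reason_embed = Retriever(method="reason-embed", model_id="qwen3-4b",
                         corpus_path="data/corpus.jsonl")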

rankify/agent/model_registry.py (171 additions, 0 deletions)

@@ -187,9 +187,180 @@ def score_for_task(self, task: TaskType) -> float:
         api_provider="serper",
         best_for=["real-time data", "current events", "web search"],
     ),
+
+    # === DIVER DENSE RETRIEVERS ===
+    # method="diver-dense", use model_path as the model_id argument
+    # Valid model_ids: bge, sbert, contriever_st, nomic, diver, inst-l, inst-xl, sf, e5, rader, m2, contriever, grit
+    "diver-bge": ModelMetadata(
+        name="Diver (BGE Large)",
+        method="diver-dense",
+        description="BAAI/bge-large-en-v1.5 via the Diver dense retrieval framework.",
+        speed=Speed.MEDIUM,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=3000,
+        best_for=["semantic search", "high accuracy", "BEIR benchmarks"],
+        model_path="bge",
+    ),
+    "diver-sbert": ModelMetadata(
+        name="Diver (SBERT all-mpnet-base-v2)",
+        method="diver-dense",
+        description="sentence-transformers/all-mpnet-base-v2 via the Diver framework.",
+        speed=Speed.FAST,
+        accuracy=Accuracy.VERY_GOOD,
+        gpu_required=True,
+        memory_mb=1500,
+        best_for=["sentence similarity", "semantic search"],
+        model_path="sbert",
+    ),
+    "diver-nomic": ModelMetadata(
+        name="Diver (Nomic Embed)",
+        method="diver-dense",
+        description="nomic-ai/nomic-embed-text-v1 via the Diver framework.",
+        speed=Speed.FAST,
+        accuracy=Accuracy.VERY_GOOD,
+        gpu_required=True,
+        memory_mb=1500,
+        best_for=["long context", "semantic search", "document retrieval"],
+        model_path="nomic",
+    ),
+    "diver-e5": ModelMetadata(
+        name="Diver (E5-Mistral-7B)",
+        method="diver-dense",
+        description="intfloat/e5-mistral-7b-instruct — instruction-tuned LLM encoder in the Diver framework.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["zero-shot retrieval", "instruction following", "complex queries"],
+        model_path="e5",
+    ),
+    "diver-sf": ModelMetadata(
+        name="Diver (SFR-Embedding-Mistral)",
+        method="diver-dense",
+        description="Salesforce/SFR-Embedding-Mistral — Salesforce Mistral-based bi-encoder in the Diver framework.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["high accuracy retrieval", "complex queries", "BEIR benchmarks"],
+        model_path="sf",
+    ),
+    "diver-rader": ModelMetadata(
+        name="Diver (RaDeR)",
+        method="diver-dense",
+        description="Raderspace/RaDeR_Qwen_25_7B — reasoning-aware dense retriever in the Diver framework.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.EXCELLENT,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["reasoning-intensive queries", "multi-hop QA", "math-related retrieval"],
+        model_path="rader",
+    ),
+    "diver-grit": ModelMetadata(
+        name="Diver (GritLM-7B)",
+        method="diver-dense",
+        description="GritLM/GritLM-7B — generative representation model in the Diver framework.",
+        speed=Speed.VERY_SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["generative retrieval", "LLM-quality embeddings", "long context"],
+        model_path="grit",
+    ),
+    "diver-model": ModelMetadata(
+        name="Diver Retriever-4B",
+        method="diver-dense",
+        description="AQ-MedAI/Diver-Retriever-4B — the flagship Diver diverse-evidence retrieval model.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=8000,
+        best_for=["diverse evidence retrieval", "BEIR benchmarks", "medical QA"],
+        model_path="diver",
+    ),
+    "diver-inst-l": ModelMetadata(
+        name="Diver (Instructor-Large)",
+        method="diver-dense",
+        description="hkunlp/instructor-large — instruction-following encoder in the Diver framework.",
+        speed=Speed.MEDIUM,
+        accuracy=Accuracy.VERY_GOOD,
+        gpu_required=True,
+        memory_mb=3000,
+        best_for=["instruction following", "domain-specific retrieval"],
+        model_path="inst-l",
+    ),
+    "diver-m2": ModelMetadata(
+        name="Diver (M2-BERT-32K)",
+        method="diver-dense",
+        description="togethercomputer/m2-bert-80M-32k-retrieval — long-context retrieval in the Diver framework.",
+        speed=Speed.MEDIUM,
+        accuracy=Accuracy.VERY_GOOD,
+        gpu_required=True,
+        memory_mb=2000,
+        best_for=["long-context retrieval", "32k sequence length"],
+        model_path="m2",
+    ),
+
+    # === REASONING-AUGMENTED RETRIEVERS ===
+    "reasonir": ModelMetadata(
+        name="ReasonIR-8B",
+        method="reasonir",
+        description="reasonir/ReasonIR-8B — SOTA reasoning-intensive retriever on the BRIGHT benchmark. No model_id needed.",
+        speed=Speed.VERY_SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["reasoning-intensive queries", "BRIGHT benchmark", "complex multi-hop QA", "science queries"],
+    ),
+    "reason-embed-qwen3-8b": ModelMetadata(
+        name="ReasonEmbed Qwen3-8B",
+        method="reason-embed",
+        description="hanhainebula/reason-embed-qwen3-8b-0928 — Qwen3-8B for reasoning retrieval. Use model_id='qwen3-8b'.",
+        speed=Speed.VERY_SLOW,
+        accuracy=Accuracy.STATE_OF_THE_ART,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["reasoning-intensive retrieval", "complex queries"],
+        model_path="qwen3-8b",
+    ),
+    "reason-embed-qwen3-4b": ModelMetadata(
+        name="ReasonEmbed Qwen3-4B",
+        method="reason-embed",
+        description="hanhainebula/reason-embed-qwen3-4b-0928 — balanced Qwen3-4B for reasoning retrieval. Use model_id='qwen3-4b'.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.EXCELLENT,
+        gpu_required=True,
+        memory_mb=8000,
+        best_for=["reasoning retrieval", "balanced accuracy/speed"],
+        model_path="qwen3-4b",
+    ),
+    "reason-embed-llama-8b": ModelMetadata(
+        name="ReasonEmbed LLaMA-3.1-8B",
+        method="reason-embed",
+        description="hanhainebula/reason-embed-llama-3.1-8b-0928 — LLaMA-3.1-8B for reasoning retrieval. Use model_id='llama-8b'.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.EXCELLENT,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["reasoning retrieval", "open-source LLaMA backbone"],
+        model_path="llama-8b",
+    ),
+    "bge-reasoner-embed": ModelMetadata(
+        name="BGE Reasoner Embed (Qwen3-8B)",
+        method="bge-reasoner-embed",
+        description="BAAI/bge-reasoner-embed-qwen3-8b-0923 — BGE reasoning-augmented retriever. No model_id needed.",
+        speed=Speed.SLOW,
+        accuracy=Accuracy.EXCELLENT,
+        gpu_required=True,
+        memory_mb=16000,
+        best_for=["reasoning-augmented retrieval", "BEIR benchmarks", "complex queries"],
+    ),
 }


+
+
 # =============================================================================
 # RERANKER REGISTRY
 # =============================================================================
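
Each new entry pairs a model with machine-readable speed, accuracy, memory, and best_for metadata, and the hunk header shows that entries expose score_for_task(). Below is a hedged sketch of how an agent might pick from this registry; RETRIEVER_REGISTRY and pick_retriever are assumed names, since the dict's declaration sits outside this hunk.

# Hypothetical selection helper over the entries added above. RETRIEVER_REGISTRY
# and this import are assumed names; TaskType and score_for_task() appear in the
# hunk header, so they are taken as given.
from rankify.agent.model_registry import RETRIEVER_REGISTRY, TaskType  # assumed

def pick_retriever(task: TaskType, max_memory_mb: int):
    """Return the best-scoring (key, metadata) pair that fits the memory budget."""
    fits = [(key, meta) for key, meta in RETRIEVER_REGISTRY.items()
            if getattr(meta, "memory_mb", None) and meta.memory_mb <= max_memory_mb]
    return max(fits, key=lambda kv: kv[1].score_for_task(task), default=None)

# On an 8 GB budget this skips the 16 GB LLM encoders (diver-e5, diver-sf,
# diver-grit, reasonir, ...) but still considers diver-bge (3 GB) or
# diver-model (8 GB).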
