Skip to content

Commit 3c77771

Browse files
committed
vector ranker
1 parent 0798623 commit 3c77771

9 files changed

Lines changed: 421 additions & 8 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,17 @@ all 500 queries no filter, includes ~48% path-signal-less queries
172172
Output goes to `bench/runs/<timestamp>__<tier>/`: `report.md`, `summary.json`,
173173
`per_query.jsonl`.
174174

175+
Block mode can optionally rerank only the cross-block merge candidates before
176+
the file/directory split:
177+
178+
```bash
179+
python bench/run_swebench_filetree.py --tier medium --strategy block --ranker vector
180+
```
181+
182+
Available rankers are `none`, `bm25`, and `vector`. The vector ranker uses
183+
LiteLLM embeddings (`--embedding-provider`, `--embedding-model`) and leaves
184+
the default `ranker=none` unchanged.
185+
175186
#### Latest Run (Claude Sonnet 4.6, `--strategy block --ranker none`, top-k=10)
176187

177188
The cutoff for each query is its gold-file count: one-gold queries use top-1,

bench/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ Outputs:
2020
- `report.md`: markdown report
2121
- `bench.sqlite`: temporary benchmark database
2222

23+
Ranker options:
24+
25+
- `--ranker none`: preserve traversal and block-local LLM order.
26+
- `--ranker bm25`: lexical path ordering for cross-block merge candidates.
27+
- `--ranker vector`: embedding path ordering for cross-block merge candidates;
28+
configure with `--embedding-provider` and `--embedding-model`.
29+
2330
### Latest Full Run
2431

2532
Claude Sonnet 4.6, `tier=all`, `strategy=block`, `ranker=none`, `top_k=10`.

bench/run_swebench_filetree.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,25 @@
3131
from contextdb.api.condb import ConDB
3232
from contextdb.retriever.algorithm.beam_retriever import BeamRetriever
3333
from contextdb.retriever.algorithm.block_retriever import BlockRetriever
34-
from contextdb.retriever.algorithm.ranker import BM25PathRanker
34+
from contextdb.retriever.algorithm.ranker import make_ranker
3535

3636
DEFAULT_MODEL = "claude-sonnet-4-6"
3737
DEFAULT_DATA_DIR = Path("data/swebench_pathonly")
3838

3939

4040
def make_filesystem_retriever(db: ConDB, args, node_count: int):
41-
ranker = BM25PathRanker() if args.ranker == "bm25" else None
4241
strategy = args.strategy
4342
if strategy == "auto":
4443
strategy = "beam" if node_count <= 50 else "block"
4544
if strategy == "beam":
4645
return BeamRetriever(db.storage, db._llm, mode="filesystem")
4746
if strategy == "block":
47+
ranker = make_ranker(
48+
args.ranker,
49+
embedding_provider=args.embedding_provider,
50+
embedding_model=args.embedding_model,
51+
embedding_api_key=args.embedding_api_key,
52+
)
4853
return BlockRetriever(
4954
db.storage,
5055
db._llm,
@@ -224,6 +229,8 @@ def run(args):
224229
"top_k": args.top_k,
225230
"strategy": args.strategy,
226231
"ranker": args.ranker,
232+
"embedding_provider": args.embedding_provider if args.ranker == "vector" else None,
233+
"embedding_model": args.embedding_model if args.ranker == "vector" else None,
227234
"limit": args.limit,
228235
"num_queries": len(queries),
229236
"num_snapshots": len(by_snap),
@@ -476,8 +483,11 @@ def main():
476483
p.add_argument("--provider", default="anthropic")
477484
p.add_argument("--top-k", type=int, default=10)
478485
p.add_argument("--strategy", choices=["auto", "beam", "block"], default="auto")
479-
p.add_argument("--ranker", choices=["bm25", "none"], default="none",
486+
p.add_argument("--ranker", choices=["bm25", "vector", "none"], default="none",
480487
help="Optional path ordering for Block merge results")
488+
p.add_argument("--embedding-provider", default="openai")
489+
p.add_argument("--embedding-model", default="text-embedding-3-small")
490+
p.add_argument("--embedding-api-key", default=None)
481491
p.add_argument("--max-parallel-blocks", type=int, default=None)
482492
p.add_argument("--max-turns", type=int, default=None)
483493
p.add_argument("--limit", type=int, default=0, help="0 = all")

contextdb/api/condb.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
BlockRetriever,
1515
TreeFormatter,
1616
)
17+
from contextdb.retriever.algorithm.ranker import Ranker, make_ranker
1718

1819
# ── Errors ──────────────────────────────────────────────────────────
1920

@@ -107,6 +108,10 @@ def query(
107108
max_tokens_per_block: int = 16000,
108109
cache_enabled: bool = True,
109110
max_parallel_blocks: int = None,
111+
ranker: str | Ranker | None = None,
112+
embedding_provider: str = "openai",
113+
embedding_model: str = None,
114+
embedding_api_key: str = None,
110115
retriever: BaseRetriever = None,
111116
) -> QueryResult:
112117
self._check_tree(tree_id)
@@ -118,6 +123,10 @@ def query(
118123
max_tokens_per_block=max_tokens_per_block,
119124
cache_enabled=cache_enabled,
120125
max_parallel_blocks=max_parallel_blocks,
126+
ranker=ranker,
127+
embedding_provider=embedding_provider,
128+
embedding_model=embedding_model,
129+
embedding_api_key=embedding_api_key,
121130
)
122131

123132
result = retriever.retrieve(tree_id, question,
@@ -136,6 +145,16 @@ def _make_retriever(self, tree_id, llm, strategy, **kwargs) -> BaseRetriever:
136145
if strategy == "auto":
137146
strategy = self._pick_strategy(tree_id)
138147
mode = self._tree_mode(tree_id)
148+
ranker = (
149+
make_ranker(
150+
kwargs.get("ranker"),
151+
embedding_provider=kwargs.get("embedding_provider", "openai"),
152+
embedding_model=kwargs.get("embedding_model"),
153+
embedding_api_key=kwargs.get("embedding_api_key"),
154+
)
155+
if strategy == "block"
156+
else None
157+
)
139158
return build_strategy_retriever(
140159
self.storage,
141160
llm,
@@ -146,7 +165,7 @@ def _make_retriever(self, tree_id, llm, strategy, **kwargs) -> BaseRetriever:
146165
cache_enabled=kwargs.get("cache_enabled", True),
147166
max_parallel_blocks=kwargs.get("max_parallel_blocks"),
148167
mode=mode,
149-
ranker=kwargs.get("ranker"),
168+
ranker=ranker,
150169
)
151170

152171
def _tree_mode(self, tree_id: str) -> str:

contextdb/retriever/algorithm/block_retriever_filesystem.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,11 @@ def _order_fs_node_id_groups_for_query(
340340
for node_id in node_ids
341341
if node_id in node_by_id
342342
]
343-
if not has_path_evidence(candidates, query):
343+
should_rank = getattr(ranker, "should_rank", None)
344+
if callable(should_rank):
345+
if not should_rank(query, candidates, context={"mode": "filesystem", "tree_id": tree_id}):
346+
return node_ids
347+
elif not has_path_evidence(candidates, query):
344348
return node_ids
345349
scores = ranker.rank(
346350
query,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Embedding adapters used by retrieval rankers."""
2+
3+
from __future__ import annotations
4+
5+
import math
6+
from typing import Any, Protocol, runtime_checkable
7+
8+
9+
@runtime_checkable
class EmbeddingClient(Protocol):
    """Structural interface for objects that can embed a batch of texts.

    The protocol is runtime-checkable, so ``isinstance(obj, EmbeddingClient)``
    succeeds for any object exposing an ``embed`` attribute; the check does
    not validate the method's signature.
    """

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per input text."""
14+
15+
16+
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two vectors.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        The cosine of the angle between ``a`` and ``b``, limited to
        ``[-1.0, 1.0]``. Returns ``0.0`` when either vector is empty or has
        zero magnitude. When the vectors differ in length, only the leading
        components they have in common are compared.
    """
    if not a or not b:
        return 0.0

    # zip() stops at the shorter vector, so a single pass accumulates the dot
    # product and both squared norms over the shared leading components.
    dot = 0.0
    norm_a_sq = 0.0
    norm_b_sq = 0.0
    for x, y in zip(a, b):
        dot += x * y
        norm_a_sq += x * x
        norm_b_sq += y * y

    if norm_a_sq <= 0.0 or norm_b_sq <= 0.0:
        return 0.0

    similarity = dot / (math.sqrt(norm_a_sq) * math.sqrt(norm_b_sq))

    # Floating-point rounding can push the ratio marginally outside the
    # mathematical range (e.g. 1.0000000000000002 for parallel vectors).
    # Explicit comparisons are used so a NaN result passes through untouched.
    if similarity > 1.0:
        return 1.0
    if similarity < -1.0:
        return -1.0
    return similarity
26+
27+
28+
def _resolve_embedding_model(provider: str, model: str | None) -> str:
29+
model = model or "text-embedding-3-small"
30+
if "/" in model:
31+
return model
32+
return f"{provider}/{model}"
33+
34+
35+
class LiteLLMEmbeddingClient:
    """Embedding client that delegates to ``litellm.embedding``."""

    def __init__(
        self,
        *,
        provider: str = "openai",
        model: str | None = None,
        api_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the request configuration.

        Args:
            provider: Provider name handed to ``_resolve_embedding_model``.
            model: Model name handed to ``_resolve_embedding_model``.
            api_key: Optional key; sent with each request only when truthy.
            **kwargs: Extra keyword arguments merged into every
                ``litellm.embedding`` call.
        """
        # Imported here rather than at module level, so litellm is only
        # required once a client is actually constructed.
        import litellm

        self._litellm = litellm
        self.model = _resolve_embedding_model(provider, model)
        self.api_key = api_key
        self.kwargs = kwargs
        # NOTE: this sets a flag on the litellm module itself, so it affects
        # every user of litellm in the process, not just this client.
        litellm.suppress_debug_info = True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Return the embeddings for ``texts``, ordered by each row's ``index``.

        An empty ``texts`` returns ``[]`` without making a request.
        """
        if not texts:
            return []

        request: dict[str, Any] = {"model": self.model, "input": texts}
        request.update(self.kwargs)
        if self.api_key:
            request["api_key"] = self.api_key

        response = self._litellm.embedding(**request)

        # Prefer attribute access; fall back to item access when the
        # attribute is missing or falsy.
        payload = getattr(response, "data", None)
        if not payload:
            payload = response["data"]

        ordered = sorted(payload, key=lambda item: _row_value(item, "index", 0))
        return [list(_row_value(item, "embedding", [])) for item in ordered]
64+
65+
66+
def _row_value(row: Any, key: str, default: Any) -> Any:
67+
if isinstance(row, dict):
68+
return row.get(key, default)
69+
return getattr(row, key, default)

0 commit comments

Comments
 (0)