Skip to content

Commit 498ec61

Browse files
committed
simplify the process of building a new reranker by creating a compute_similarity method
1 parent 4a8d9bd commit 498ec61

4 files changed

Lines changed: 107 additions & 152 deletions

File tree

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1+
import heapq
12
import logging
23
from abc import ABC, abstractmethod
3-
from typing import Any
4+
from collections import defaultdict
5+
from typing import Any, DefaultDict, Optional, Sequence, cast
46

7+
import numpy
58
from chromadb.api.types import QueryResult
69

7-
from vectorcode.cli_utils import Config
10+
from vectorcode.cli_utils import Config, QueryInclude
811

912
logger = logging.getLogger(name=__name__)
1013

1114

1215
class RerankerBase(ABC):
1316
"""This is the base class for the rerankers.
1417
You should use the configs.reranker_params field to store and pass the parameters used for your reranker.
15-
You should implement the `rerank` method, which returns a list of chunk IDs if QueryInclude.chunk is in configs.include, or a list of paths otherwise.
18+
You should implement the `compute_similarity` method, which will be called by `rerank` to compute
19+
similarity scores between search query and results.
1620
The list returned by `rerank` is sorted such that the relevance decreases along the list.
1721
1822
The class doc string will be added to the error message if your reranker fails to initialise.
@@ -21,7 +25,11 @@ class RerankerBase(ABC):
2125

2226
def __init__(self, configs: Config, **kwargs: Any):
2327
self.configs = configs
28+
assert self.configs.query is not None, (
29+
"'configs' should contain the query messages."
30+
)
2431
self.n_result = configs.n_result
32+
self._raw_results: Optional[QueryResult] = None
2533

2634
@classmethod
2735
def create(cls, configs: Config, **kwargs: Any):
@@ -38,5 +46,48 @@ def create(cls, configs: Config, **kwargs: Any):
3846
raise
3947

4048
@abstractmethod
41-
def rerank(self, results: QueryResult) -> list[str]: # pragma: nocover
49+
def compute_similarity(
50+
self, results: list[str], query_message: str
51+
) -> Sequence[float]: # pragma: nocover
52+
"""Given a list of n results and 1 query message,
53+
return a list-like object of length n that contains the similarity scores between
54+
each item in `results` and the `query_message`.
55+
56+
A high similarity score means the strings are semantically similar to each other.
57+
Query messages are passed in one per call, in the same order as they appear in `self.configs.query`.
58+
59+
If you need the raw query results from chromadb,
60+
it'll be saved in `self._raw_results` before this method is called.
61+
"""
4262
raise NotImplementedError
63+
64+
def rerank(self, results: QueryResult | dict) -> list[str]:
    """Rank the query results using the scores from `compute_similarity`.

    For every query message in `self.configs.query`, the matching row of
    `results` is scored with `compute_similarity`. Scores are grouped by
    chunk ID when `QueryInclude.chunk` is in `self.configs.include`, and by
    the `path` metadata otherwise. Each group keeps only its best `top_k`
    scores (`top_k` being the integer mean group size), and the groups are
    then ranked by their mean score.

    Args:
        results: The raw query results. `ids`, `metadatas` and `documents`
            are indexed by the position of the query message.

    Returns:
        At most `self.n_result` chunk IDs or paths, most relevant first.
        An empty list is returned when there is nothing to rank.
    """
    query_chunks = self.configs.query
    assert query_chunks
    assert results["metadatas"] is not None
    assert results["documents"] is not None

    # Loop-invariant: decide once whether results are keyed by chunk ID or path.
    use_chunk_ids = QueryInclude.chunk in self.configs.include
    documents: DefaultDict[str, list[float]] = defaultdict(list)

    # Expose the raw results to `compute_similarity` for the duration of the
    # scoring only; `finally` guarantees they are cleared even if it raises.
    self._raw_results = cast(QueryResult, results)
    try:
        for query_chunk_idx, query_chunk in enumerate(query_chunks):
            chunk_ids = results["ids"][query_chunk_idx]
            chunk_metas = results["metadatas"][query_chunk_idx]
            chunk_docs = results["documents"][query_chunk_idx]
            scores = self.compute_similarity(chunk_docs, query_chunk)
            for i, score in enumerate(scores):
                if use_chunk_ids:
                    key = chunk_ids[i]
                else:
                    key = str(chunk_metas[i]["path"])
                documents[key].append(float(score))
    finally:
        self._raw_results = None

    logger.debug("Document scores: %s", documents)
    if not documents:
        # The mean of an empty sequence is NaN, which cannot be converted to
        # an int; there is nothing to rank anyway.
        return []

    top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
    for key in documents:
        documents[key] = heapq.nlargest(top_k, documents[key])

    return heapq.nlargest(
        self.n_result,
        documents,
        key=lambda x: float(numpy.mean(documents[x])),
    )
Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import heapq
21
import logging
3-
from collections import defaultdict
4-
from typing import Any, DefaultDict
2+
from typing import Any
53

6-
import numpy
7-
8-
from vectorcode.cli_utils import Config, QueryInclude
4+
from vectorcode.cli_utils import Config
95

106
from .base import RerankerBase
117

@@ -25,9 +21,6 @@ def __init__(
2521
**kwargs: Any,
2622
):
2723
super().__init__(configs)
28-
assert self.configs.query is not None, (
29-
"'configs' should contain the query messages."
30-
)
3124
from sentence_transformers import CrossEncoder
3225

3326
if configs.reranker_params.get("model_name_or_path") is None:
@@ -39,33 +32,8 @@ def __init__(
3932
)
4033
self.model = CrossEncoder(**configs.reranker_params)
4134

42-
def rerank(self, results) -> list[str]:
43-
assert self.configs.query
44-
query_chunks = self.configs.query
45-
assert results["metadatas"] is not None
46-
assert results["documents"] is not None
47-
documents: DefaultDict[str, list[float]] = defaultdict(list)
48-
for query_chunk_idx in range(len(query_chunks)):
49-
chunk_ids = results["ids"][query_chunk_idx]
50-
chunk_metas = results["metadatas"][query_chunk_idx]
51-
chunk_docs = results["documents"][query_chunk_idx]
52-
ranks = self.model.rank(
53-
query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
54-
)
55-
for rank in ranks:
56-
if QueryInclude.chunk in self.configs.include:
57-
documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"]))
58-
else:
59-
documents[chunk_metas[rank["corpus_id"]]["path"]].append(
60-
float(rank["score"])
61-
)
62-
logger.debug("Document scores: %s", documents)
63-
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
64-
for key in documents.keys():
65-
documents[key] = heapq.nlargest(top_k, documents[key])
66-
67-
return heapq.nlargest(
68-
self.n_result,
69-
documents.keys(),
70-
key=lambda x: float(numpy.mean(documents[x])),
35+
def compute_similarity(self, results: list[str], query_message: str):
    """Score every chunk in `results` against `query_message` with the cross encoder.

    Args:
        results: The document chunks to score.
        query_message: The query message the chunks are compared to.

    Returns:
        A list of floats with one score per item of `results`, in the same
        order. An empty `results` yields an empty list without calling the
        model.
    """
    if not results:
        return []
    # Pairs are (query, document): this is the order the previously used
    # `self.model.rank(query, documents)` call scored them in.
    pairs = [(query_message, chunk) for chunk in results]
    return [float(score) for score in self.model.predict(pairs)]
Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import heapq
21
import logging
3-
from collections import defaultdict
4-
from typing import Any, DefaultDict
2+
from typing import Any, Sequence
53

6-
import numpy
7-
8-
from vectorcode.cli_utils import Config, QueryInclude
4+
from vectorcode.cli_utils import Config
95

106
from .base import RerankerBase
117

@@ -21,30 +17,11 @@ class NaiveReranker(RerankerBase):
2117
def __init__(self, configs: Config, **kwargs: Any):
2218
super().__init__(configs)
2319

24-
def rerank(self, results) -> list[str]:
25-
assert results["metadatas"] is not None
26-
assert results["distances"] is not None
27-
documents: DefaultDict[str, list[float]] = defaultdict(list)
28-
for query_chunk_idx in range(len(results["ids"])):
29-
chunk_ids = results["ids"][query_chunk_idx]
30-
chunk_metas = results["metadatas"][query_chunk_idx]
31-
chunk_distances = results["distances"][query_chunk_idx]
32-
# NOTE: distances, smaller is better.
33-
paths = [str(meta["path"]) for meta in chunk_metas]
34-
assert len(paths) == len(chunk_distances)
35-
for distance, identifier in zip(
36-
chunk_distances,
37-
chunk_ids if QueryInclude.chunk in self.configs.include else paths,
38-
):
39-
if identifier is None: # pragma: nocover
40-
# so that vectorcode doesn't break on old collections.
41-
continue
42-
documents[identifier].append(distance)
43-
logger.debug("Document scores: %s", documents)
44-
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
45-
for key in documents.keys():
46-
documents[key] = heapq.nsmallest(top_k, documents[key])
47-
48-
return heapq.nsmallest(
49-
self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x]))
50-
)
20+
def compute_similarity(
    self, results: list[str], query_message: str
) -> Sequence[float]:
    """Turn the database distances for `query_message` into similarity scores.

    The scores are not computed from `results`; they are read from the
    `distances` field of `self._raw_results`, using the row at the position
    of `query_message` in `self.configs.query`.

    Args:
        results: The document chunks for this query message (unused here).
        query_message: One of the messages in `self.configs.query`.

    Returns:
        The negated distances of the matching row, so that a smaller
        distance gives a higher score.

    Raises:
        ValueError: If `query_message` is not in `self.configs.query`.
    """
    assert self._raw_results is not None, "Expecting raw results from the database."
    # Look the distances up once so the None check covers the value that is used.
    distances = self._raw_results.get("distances")
    assert distances is not None
    assert self.configs.query, "Expecting query messages in self.configs"
    idx = self.configs.query.index(query_message)
    return [-distance for distance in distances[idx]]

tests/subcommands/query/test_reranker.py

Lines changed: 36 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import MagicMock, patch
22

3+
import numpy
34
import pytest
45

56
from vectorcode.cli_utils import Config, QueryInclude
@@ -94,24 +95,6 @@ def test_naive_reranker_rerank(naive_reranker_conf, query_result):
9495
assert isinstance(path, str)
9596

9697

97-
def test_naive_reranker_handles_none_path(config, query_result):
98-
"""Test NaiveReranker properly handles None paths in metadata"""
99-
# Create a copy with a None path
100-
query_result_with_none = query_result.copy()
101-
query_result_with_none["metadatas"] = [
102-
[{"path": "file1.py"}, {"path": None}, {"path": "file3.py"}],
103-
[{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
104-
]
105-
106-
reranker = NaiveReranker(config)
107-
result = reranker.rerank(query_result_with_none)
108-
109-
# Check the None path was handled without errors
110-
assert isinstance(result, list)
111-
# None should be filtered out
112-
assert None not in result
113-
114-
11598
@patch("sentence_transformers.CrossEncoder")
11699
def test_cross_encoder_reranker_initialization(mock_cross_encoder: MagicMock, config):
117100
reranker = CrossEncoderReranker(config)
@@ -141,43 +124,34 @@ def test_cross_encoder_reranker_rerank(
141124
mock_model = MagicMock()
142125
mock_cross_encoder.return_value = mock_model
143126

144-
# Configure mock rank method to return predetermined ranks
145-
mock_model.rank.return_value = [
146-
{"corpus_id": 0, "score": 0.9},
147-
{"corpus_id": 1, "score": 0.7},
148-
{"corpus_id": 2, "score": 0.8},
149-
]
127+
# Configure mock predict to return numpy array with float32 dtype
128+
scores = numpy.array([0.9, 0.7, 0.8], dtype=numpy.float32)
129+
mock_model.predict.return_value = scores
150130

151-
reranker = CrossEncoderReranker(config)
131+
# Ensure complete query_result structure
132+
query_result.update(
133+
{
134+
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
135+
"documents": [["doc1", "doc2", "doc3"], ["doc4", "doc5", "doc6"]],
136+
"metadatas": [
137+
[{"path": "p1"}, {"path": "p2"}, {"path": "p3"}],
138+
[{"path": "p4"}, {"path": "p5"}, {"path": "p6"}],
139+
],
140+
}
141+
)
152142

143+
reranker = CrossEncoderReranker(config)
153144
result = reranker.rerank(query_result)
154145

155-
# Verify the model was called with correct parameters
156-
mock_model.rank.assert_called()
157-
158-
# Check result
146+
# Result assertions
159147
assert isinstance(result, list)
148+
assert all(isinstance(path, str) for path in result)
160149
assert len(result) <= config.n_result
161150

162-
# Check all returned items are strings (paths)
163-
for path in result:
164-
assert isinstance(path, str)
165-
166151

167-
def test_naive_reranker_document_selection_logic(naive_reranker_conf):
152+
def test_naive_reranker_document_selection_logic(naive_reranker_conf, query_result):
168153
"""Test that NaiveReranker correctly selects documents based on distances"""
169154
# Create a query result with known distances
170-
query_result = {
171-
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
172-
"distances": [
173-
[0.3, 0.1, 0.2], # file2 has lowest, then file3, then file1
174-
[0.6, 0.4, 0.5], # file4 has lowest, then file3, then file2
175-
],
176-
"metadatas": [
177-
[{"path": "file1.py"}, {"path": "file2.py"}, {"path": "file3.py"}],
178-
[{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
179-
],
180-
}
181155

182156
reranker = NaiveReranker(naive_reranker_conf)
183157
result = reranker.rerank(query_result)
@@ -188,19 +162,12 @@ def test_naive_reranker_document_selection_logic(naive_reranker_conf):
188162
assert "file2.py" in result or "file3.py" in result
189163

190164

191-
def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
165+
def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
192166
"""Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
193167
naive_reranker_conf.include.append(
194168
QueryInclude.chunk
195169
) # Assuming QueryInclude.chunk would be "chunk"
196-
query_result = {
197-
"ids": [["id1", "id2"], ["id3", "id1"]],
198-
"distances": [[0.1, 0.2], [0.3, 0.4]],
199-
"metadatas": [
200-
[{"path": "file1.py"}, {"path": "file2.py"}],
201-
[{"path": "file3.py"}, {"path": "file1.py"}],
202-
],
203-
}
170+
204171
reranker = NaiveReranker(naive_reranker_conf)
205172
result = reranker.rerank(query_result)
206173

@@ -212,33 +179,22 @@ def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
212179

213180
@patch("sentence_transformers.CrossEncoder")
214181
def test_cross_encoder_reranker_with_chunk_ids(
215-
mock_cross_encoder, config, query_chunks
182+
mock_cross_encoder, config, query_result
216183
):
217184
"""Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
218185
mock_model = MagicMock()
219186
mock_cross_encoder.return_value = mock_model
220-
mock_model.rank.return_value = [
221-
{"corpus_id": 0, "score": 0.9},
222-
{"corpus_id": 1, "score": 0.7},
223-
]
224-
225-
config.include = {"chunk"} # Use comma instead of append
226-
reranker = CrossEncoderReranker(
227-
config,
228-
)
229187

230-
# Match query_chunks length with results
231-
result = reranker.rerank(
232-
{
233-
"ids": [["id1", "id2"], ["id3", "id4"]], # Two query chunks
234-
"metadatas": [
235-
[{"path": "file1.py"}, {"path": "file2.py"}],
236-
[{"path": "file3.py"}, {"path": "file4.py"}],
237-
],
238-
"documents": [["doc1", "doc2"], ["doc3", "doc4"]],
239-
},
240-
)
188+
# Setup mock to return numpy array scores
189+
scores = numpy.array([0.9, 0.7], dtype=numpy.float32)
190+
mock_model.predict.return_value = scores
191+
192+
config.include = {QueryInclude.chunk}
193+
reranker = CrossEncoderReranker(config)
241194

195+
result = reranker.rerank(query_result)
196+
197+
mock_model.predict.assert_called()
242198
assert isinstance(result, list)
243199
assert all(isinstance(id, str) for id in result)
244200
assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
@@ -275,12 +231,15 @@ def test_add_reranker_success():
275231

276232
@add_reranker
277233
class TestReranker(RerankerBase):
278-
def rerank(self, results, query_chunks):
234+
def compute_similarity(self, results, query_message):
279235
return []
280236

281237
assert len(get_available_rerankers()) == original_count + 1
282238
assert "TestReranker" in __supported_rerankers
283-
assert isinstance(get_reranker(Config(reranker="TestReranker")), TestReranker)
239+
assert isinstance(
240+
get_reranker(Config(reranker="TestReranker", query=["hello world"])),
241+
TestReranker,
242+
)
284243
__supported_rerankers.pop("TestReranker")
285244

286245

0 commit comments

Comments
 (0)