Skip to content

Commit d61a68a

Browse files
committed
simplify the process of building a new reranker by creating a compute_similarity method
1 parent 4a8d9bd commit d61a68a

4 files changed

Lines changed: 100 additions & 151 deletions

File tree

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import heapq
12
import logging
23
from abc import ABC, abstractmethod
3-
from typing import Any
4+
from collections import defaultdict
5+
from typing import Any, DefaultDict, Optional, Sequence, cast
46

7+
import numpy
58
from chromadb.api.types import QueryResult
69

7-
from vectorcode.cli_utils import Config
10+
from vectorcode.cli_utils import Config, QueryInclude
811

912
logger = logging.getLogger(name=__name__)
1013

@@ -21,7 +24,11 @@ class RerankerBase(ABC):
2124

2225
def __init__(self, configs: Config, **kwargs: Any):
2326
self.configs = configs
27+
assert self.configs.query is not None, (
28+
"'configs' should contain the query messages."
29+
)
2430
self.n_result = configs.n_result
31+
self._raw_results: Optional[QueryResult] = None
2532

2633
@classmethod
2734
def create(cls, configs: Config, **kwargs: Any):
@@ -38,5 +45,43 @@ def create(cls, configs: Config, **kwargs: Any):
3845
raise
3946

4047
@abstractmethod
41-
def rerank(self, results: QueryResult) -> list[str]: # pragma: nocover
48+
def compute_similarity(
49+
self, results: list[str], query_message: str
50+
) -> Sequence: # pragma: nocover
51+
"""Given a list of n results and 1 query message,
52+
return a list-like object of length n that contains the similarity scores between
53+
each item in `results` and the `query_message`.
54+
55+
A high similarity score means the strings are semantically similar to each other.
56+
`query_message` will be loaded in the same order as they appear in `self.configs.query`
57+
"""
4258
raise NotImplementedError
59+
60+
def rerank(self, results: QueryResult | dict) -> list[str]:
61+
self._raw_results = cast(QueryResult, results)
62+
query_chunks = self.configs.query
63+
assert query_chunks
64+
assert results["metadatas"] is not None
65+
assert results["documents"] is not None
66+
documents: DefaultDict[str, list[float]] = defaultdict(list)
67+
for query_chunk_idx in range(len(query_chunks)):
68+
chunk_ids = results["ids"][query_chunk_idx]
69+
chunk_metas = results["metadatas"][query_chunk_idx]
70+
chunk_docs = results["documents"][query_chunk_idx]
71+
scores = self.compute_similarity(chunk_docs, query_chunks[query_chunk_idx])
72+
for i, score in enumerate(scores):
73+
if QueryInclude.chunk in self.configs.include:
74+
documents[chunk_ids[i]].append(float(score))
75+
else:
76+
documents[chunk_metas[i]["path"]].append(float(score))
77+
78+
logger.debug("Document scores: %s", documents)
79+
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
80+
for key in documents.keys():
81+
documents[key] = heapq.nlargest(top_k, documents[key])
82+
83+
return heapq.nlargest(
84+
self.n_result,
85+
documents.keys(),
86+
key=lambda x: float(numpy.mean(documents[x])),
87+
)
Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import heapq
21
import logging
3-
from collections import defaultdict
4-
from typing import Any, DefaultDict
2+
from typing import Any
53

6-
import numpy
7-
8-
from vectorcode.cli_utils import Config, QueryInclude
4+
from vectorcode.cli_utils import Config
95

106
from .base import RerankerBase
117

@@ -25,9 +21,6 @@ def __init__(
2521
**kwargs: Any,
2622
):
2723
super().__init__(configs)
28-
assert self.configs.query is not None, (
29-
"'configs' should contain the query messages."
30-
)
3124
from sentence_transformers import CrossEncoder
3225

3326
if configs.reranker_params.get("model_name_or_path") is None:
@@ -39,33 +32,8 @@ def __init__(
3932
)
4033
self.model = CrossEncoder(**configs.reranker_params)
4134

42-
def rerank(self, results) -> list[str]:
43-
assert self.configs.query
44-
query_chunks = self.configs.query
45-
assert results["metadatas"] is not None
46-
assert results["documents"] is not None
47-
documents: DefaultDict[str, list[float]] = defaultdict(list)
48-
for query_chunk_idx in range(len(query_chunks)):
49-
chunk_ids = results["ids"][query_chunk_idx]
50-
chunk_metas = results["metadatas"][query_chunk_idx]
51-
chunk_docs = results["documents"][query_chunk_idx]
52-
ranks = self.model.rank(
53-
query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
54-
)
55-
for rank in ranks:
56-
if QueryInclude.chunk in self.configs.include:
57-
documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"]))
58-
else:
59-
documents[chunk_metas[rank["corpus_id"]]["path"]].append(
60-
float(rank["score"])
61-
)
62-
logger.debug("Document scores: %s", documents)
63-
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
64-
for key in documents.keys():
65-
documents[key] = heapq.nlargest(top_k, documents[key])
66-
67-
return heapq.nlargest(
68-
self.n_result,
69-
documents.keys(),
70-
key=lambda x: float(numpy.mean(documents[x])),
35+
def compute_similarity(self, results: list[str], query_message: str):
36+
return list(
37+
float(i)
38+
for i in self.model.predict([(chunk, query_message) for chunk in results])
7139
)
Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import heapq
21
import logging
3-
from collections import defaultdict
4-
from typing import Any, DefaultDict
2+
from typing import Any, Sequence
53

6-
import numpy
7-
8-
from vectorcode.cli_utils import Config, QueryInclude
4+
from vectorcode.cli_utils import Config
95

106
from .base import RerankerBase
117

@@ -21,30 +17,11 @@ class NaiveReranker(RerankerBase):
2117
def __init__(self, configs: Config, **kwargs: Any):
2218
super().__init__(configs)
2319

24-
def rerank(self, results) -> list[str]:
25-
assert results["metadatas"] is not None
26-
assert results["distances"] is not None
27-
documents: DefaultDict[str, list[float]] = defaultdict(list)
28-
for query_chunk_idx in range(len(results["ids"])):
29-
chunk_ids = results["ids"][query_chunk_idx]
30-
chunk_metas = results["metadatas"][query_chunk_idx]
31-
chunk_distances = results["distances"][query_chunk_idx]
32-
# NOTE: distances, smaller is better.
33-
paths = [str(meta["path"]) for meta in chunk_metas]
34-
assert len(paths) == len(chunk_distances)
35-
for distance, identifier in zip(
36-
chunk_distances,
37-
chunk_ids if QueryInclude.chunk in self.configs.include else paths,
38-
):
39-
if identifier is None: # pragma: nocover
40-
# so that vectorcode doesn't break on old collections.
41-
continue
42-
documents[identifier].append(distance)
43-
logger.debug("Document scores: %s", documents)
44-
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
45-
for key in documents.keys():
46-
documents[key] = heapq.nsmallest(top_k, documents[key])
47-
48-
return heapq.nsmallest(
49-
self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x]))
50-
)
20+
def compute_similarity(
21+
self, results: list[str], query_message: str
22+
) -> Sequence[float]:
23+
assert self._raw_results is not None, "Expecting raw results from the database."
24+
assert self._raw_results.get("distances") is not None
25+
assert self.configs.query, "Expecting query messages in self.configs"
26+
idx = self.configs.query.index(query_message)
27+
return list(-i for i in self._raw_results.get("distances")[idx])

tests/subcommands/query/test_reranker.py

Lines changed: 36 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from unittest.mock import MagicMock, patch
22

3+
import numpy
34
import pytest
45

56
from vectorcode.cli_utils import Config, QueryInclude
@@ -94,24 +95,6 @@ def test_naive_reranker_rerank(naive_reranker_conf, query_result):
9495
assert isinstance(path, str)
9596

9697

97-
def test_naive_reranker_handles_none_path(config, query_result):
98-
"""Test NaiveReranker properly handles None paths in metadata"""
99-
# Create a copy with a None path
100-
query_result_with_none = query_result.copy()
101-
query_result_with_none["metadatas"] = [
102-
[{"path": "file1.py"}, {"path": None}, {"path": "file3.py"}],
103-
[{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
104-
]
105-
106-
reranker = NaiveReranker(config)
107-
result = reranker.rerank(query_result_with_none)
108-
109-
# Check the None path was handled without errors
110-
assert isinstance(result, list)
111-
# None should be filtered out
112-
assert None not in result
113-
114-
11598
@patch("sentence_transformers.CrossEncoder")
11699
def test_cross_encoder_reranker_initialization(mock_cross_encoder: MagicMock, config):
117100
reranker = CrossEncoderReranker(config)
@@ -141,43 +124,34 @@ def test_cross_encoder_reranker_rerank(
141124
mock_model = MagicMock()
142125
mock_cross_encoder.return_value = mock_model
143126

144-
# Configure mock rank method to return predetermined ranks
145-
mock_model.rank.return_value = [
146-
{"corpus_id": 0, "score": 0.9},
147-
{"corpus_id": 1, "score": 0.7},
148-
{"corpus_id": 2, "score": 0.8},
149-
]
127+
# Configure mock predict to return numpy array with float32 dtype
128+
scores = numpy.array([0.9, 0.7, 0.8], dtype=numpy.float32)
129+
mock_model.predict.return_value = scores
150130

151-
reranker = CrossEncoderReranker(config)
131+
# Ensure complete query_result structure
132+
query_result.update(
133+
{
134+
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
135+
"documents": [["doc1", "doc2", "doc3"], ["doc4", "doc5", "doc6"]],
136+
"metadatas": [
137+
[{"path": "p1"}, {"path": "p2"}, {"path": "p3"}],
138+
[{"path": "p4"}, {"path": "p5"}, {"path": "p6"}],
139+
],
140+
}
141+
)
152142

143+
reranker = CrossEncoderReranker(config)
153144
result = reranker.rerank(query_result)
154145

155-
# Verify the model was called with correct parameters
156-
mock_model.rank.assert_called()
157-
158-
# Check result
146+
# Result assertions
159147
assert isinstance(result, list)
148+
assert all(isinstance(path, str) for path in result)
160149
assert len(result) <= config.n_result
161150

162-
# Check all returned items are strings (paths)
163-
for path in result:
164-
assert isinstance(path, str)
165-
166151

167-
def test_naive_reranker_document_selection_logic(naive_reranker_conf):
152+
def test_naive_reranker_document_selection_logic(naive_reranker_conf, query_result):
168153
"""Test that NaiveReranker correctly selects documents based on distances"""
169154
# Create a query result with known distances
170-
query_result = {
171-
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
172-
"distances": [
173-
[0.3, 0.1, 0.2], # file2 has lowest, then file3, then file1
174-
[0.6, 0.4, 0.5], # file4 has lowest, then file3, then file2
175-
],
176-
"metadatas": [
177-
[{"path": "file1.py"}, {"path": "file2.py"}, {"path": "file3.py"}],
178-
[{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
179-
],
180-
}
181155

182156
reranker = NaiveReranker(naive_reranker_conf)
183157
result = reranker.rerank(query_result)
@@ -188,19 +162,12 @@ def test_naive_reranker_document_selection_logic(naive_reranker_conf):
188162
assert "file2.py" in result or "file3.py" in result
189163

190164

191-
def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
165+
def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
192166
"""Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
193167
naive_reranker_conf.include.append(
194168
QueryInclude.chunk
195169
) # Assuming QueryInclude.chunk would be "chunk"
196-
query_result = {
197-
"ids": [["id1", "id2"], ["id3", "id1"]],
198-
"distances": [[0.1, 0.2], [0.3, 0.4]],
199-
"metadatas": [
200-
[{"path": "file1.py"}, {"path": "file2.py"}],
201-
[{"path": "file3.py"}, {"path": "file1.py"}],
202-
],
203-
}
170+
204171
reranker = NaiveReranker(naive_reranker_conf)
205172
result = reranker.rerank(query_result)
206173

@@ -212,33 +179,22 @@ def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
212179

213180
@patch("sentence_transformers.CrossEncoder")
214181
def test_cross_encoder_reranker_with_chunk_ids(
215-
mock_cross_encoder, config, query_chunks
182+
mock_cross_encoder, config, query_result
216183
):
217184
"""Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
218185
mock_model = MagicMock()
219186
mock_cross_encoder.return_value = mock_model
220-
mock_model.rank.return_value = [
221-
{"corpus_id": 0, "score": 0.9},
222-
{"corpus_id": 1, "score": 0.7},
223-
]
224-
225-
config.include = {"chunk"} # Use comma instead of append
226-
reranker = CrossEncoderReranker(
227-
config,
228-
)
229187

230-
# Match query_chunks length with results
231-
result = reranker.rerank(
232-
{
233-
"ids": [["id1", "id2"], ["id3", "id4"]], # Two query chunks
234-
"metadatas": [
235-
[{"path": "file1.py"}, {"path": "file2.py"}],
236-
[{"path": "file3.py"}, {"path": "file4.py"}],
237-
],
238-
"documents": [["doc1", "doc2"], ["doc3", "doc4"]],
239-
},
240-
)
188+
# Setup mock to return numpy array scores
189+
scores = numpy.array([0.9, 0.7], dtype=numpy.float32)
190+
mock_model.predict.return_value = scores
191+
192+
config.include = {QueryInclude.chunk}
193+
reranker = CrossEncoderReranker(config)
241194

195+
result = reranker.rerank(query_result)
196+
197+
mock_model.predict.assert_called()
242198
assert isinstance(result, list)
243199
assert all(isinstance(id, str) for id in result)
244200
assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
@@ -275,12 +231,15 @@ def test_add_reranker_success():
275231

276232
@add_reranker
277233
class TestReranker(RerankerBase):
278-
def rerank(self, results, query_chunks):
234+
def compute_similarity(self, results, query_message):
279235
return []
280236

281237
assert len(get_available_rerankers()) == original_count + 1
282238
assert "TestReranker" in __supported_rerankers
283-
assert isinstance(get_reranker(Config(reranker="TestReranker")), TestReranker)
239+
assert isinstance(
240+
get_reranker(Config(reranker="TestReranker", query=["hello world"])),
241+
TestReranker,
242+
)
284243
__supported_rerankers.pop("TestReranker")
285244

286245

0 commit comments

Comments
 (0)