Skip to content

Commit 657bef4

Browse files
author
Zhe Yu
committed
make rerank and compute_similarity async.
1 parent 498ec61 commit 657bef4

6 files changed

Lines changed: 43 additions & 33 deletions

File tree

src/vectorcode/subcommands/query/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ async def get_query_result_files(
6868
return []
6969

7070
reranker = get_reranker(configs)
71-
return reranker.rerank(results)
71+
return await reranker.rerank(results)
7272

7373

7474
async def build_query_results(

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def create(cls, configs: Config, **kwargs: Any):
4646
raise
4747

4848
@abstractmethod
49-
def compute_similarity(
49+
async def compute_similarity(
5050
self, results: list[str], query_message: str
5151
) -> Sequence[float]: # pragma: nocover
5252
"""Given a list of n results and 1 query message,
@@ -61,7 +61,7 @@ def compute_similarity(
6161
"""
6262
raise NotImplementedError
6363

64-
def rerank(self, results: QueryResult | dict) -> list[str]:
64+
async def rerank(self, results: QueryResult | dict) -> list[str]:
6565
self._raw_results = cast(QueryResult, results)
6666
query_chunks = self.configs.query
6767
assert query_chunks
@@ -72,7 +72,9 @@ def rerank(self, results: QueryResult | dict) -> list[str]:
7272
chunk_ids = results["ids"][query_chunk_idx]
7373
chunk_metas = results["metadatas"][query_chunk_idx]
7474
chunk_docs = results["documents"][query_chunk_idx]
75-
scores = self.compute_similarity(chunk_docs, query_chunks[query_chunk_idx])
75+
scores = await self.compute_similarity(
76+
chunk_docs, query_chunks[query_chunk_idx]
77+
)
7678
for i, score in enumerate(scores):
7779
if QueryInclude.chunk in self.configs.include:
7880
documents[chunk_ids[i]].append(float(score))

src/vectorcode/subcommands/query/reranker/cross_encoder.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import Any
34

@@ -32,8 +33,8 @@ def __init__(
3233
)
3334
self.model = CrossEncoder(**configs.reranker_params)
3435

35-
def compute_similarity(self, results: list[str], query_message: str):
36-
return list(
37-
float(i)
38-
for i in self.model.predict([(chunk, query_message) for chunk in results])
36+
async def compute_similarity(self, results: list[str], query_message: str):
37+
scores = await asyncio.to_thread(
38+
self.model.predict, [(chunk, query_message) for chunk in results]
3939
)
40+
return list(float(i) for i in scores)

src/vectorcode/subcommands/query/reranker/naive.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ class NaiveReranker(RerankerBase):
1717
def __init__(self, configs: Config, **kwargs: Any):
1818
super().__init__(configs)
1919

20-
def compute_similarity(
20+
async def compute_similarity(
2121
self, results: list[str], query_message: str
2222
) -> Sequence[float]:
2323
assert self._raw_results is not None, "Expecting raw results from the database."

tests/subcommands/query/test_query.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -64,11 +64,13 @@ async def test_get_query_result_files(mock_collection, mock_config):
6464
# Mock the reranker
6565
with patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker:
6666
mock_reranker_instance = MagicMock()
67-
mock_reranker_instance.rerank.return_value = [
68-
"file1.py",
69-
"file2.py",
70-
"file3.py",
71-
]
67+
mock_reranker_instance.rerank = AsyncMock(
68+
return_value=[
69+
"file1.py",
70+
"file2.py",
71+
"file3.py",
72+
]
73+
)
7274
mock_get_reranker.return_value = mock_reranker_instance
7375

7476
# Call the function
@@ -103,7 +105,7 @@ async def test_get_query_result_files_include_chunk(mock_collection, mock_config
103105

104106
with patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker:
105107
mock_reranker_instance = MagicMock()
106-
mock_reranker_instance.rerank.return_value = ["chunk1"]
108+
mock_reranker_instance.rerank = AsyncMock(return_value=["chunk1"])
107109
MockReranker.return_value = mock_reranker_instance
108110

109111
await get_query_result_files(mock_collection, mock_config)
@@ -188,7 +190,7 @@ async def test_get_query_result_files_with_query_exclude(mock_collection, mock_c
188190
mock_expand_path.return_value = "/excluded/path.py"
189191

190192
mock_reranker_instance = MagicMock()
191-
mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"]
193+
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
192194
MockReranker.return_value = mock_reranker_instance
193195

194196
# Call the function
@@ -211,7 +213,7 @@ async def test_get_query_result_files_with_cross_encoder_reranker(
211213
"vectorcode.subcommands.query.reranker.CrossEncoderReranker"
212214
) as MockCrossEncoder:
213215
mock_reranker_instance = MagicMock()
214-
mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"]
216+
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
215217
MockCrossEncoder.return_value = mock_reranker_instance
216218

217219
# Call the function
@@ -266,7 +268,7 @@ async def test_get_query_result_files_chunking(mock_collection, mock_config):
266268
MockChunker.return_value = mock_chunker_instance
267269

268270
mock_reranker_instance = MagicMock()
269-
mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"]
271+
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
270272
MockReranker.return_value = mock_reranker_instance
271273

272274
# Call the function
@@ -300,7 +302,7 @@ async def test_get_query_result_files_multiple_queries(mock_collection, mock_con
300302
MockChunker.return_value = mock_chunker_instance
301303

302304
mock_reranker_instance = MagicMock()
303-
mock_reranker_instance.rerank.return_value = ["file1.py", "file2.py"]
305+
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
304306
MockReranker.return_value = mock_reranker_instance
305307

306308
# Call the function

tests/subcommands/query/test_reranker.py

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -81,10 +81,11 @@ def __init__(self, configs, **kwargs):
8181
TestReranker.create(Config())
8282

8383

84-
def test_naive_reranker_rerank(naive_reranker_conf, query_result):
84+
@pytest.mark.asyncio
85+
async def test_naive_reranker_rerank(naive_reranker_conf, query_result):
8586
"""Test basic reranking functionality of NaiveReranker"""
8687
reranker = NaiveReranker(naive_reranker_conf)
87-
result = reranker.rerank(query_result)
88+
result = await reranker.rerank(query_result)
8889

8990
# Check the result is a list of paths with correct length
9091
assert isinstance(result, list)
@@ -117,10 +118,9 @@ def test_cross_encoder_reranker_initialization_fallback_model_name(
117118
assert reranker.n_result == config.n_result
118119

119120

121+
@pytest.mark.asyncio
120122
@patch("sentence_transformers.CrossEncoder")
121-
def test_cross_encoder_reranker_rerank(
122-
mock_cross_encoder, config, query_result, query_chunks
123-
):
123+
async def test_cross_encoder_reranker_rerank(mock_cross_encoder, config, query_result):
124124
mock_model = MagicMock()
125125
mock_cross_encoder.return_value = mock_model
126126

@@ -141,44 +141,49 @@ def test_cross_encoder_reranker_rerank(
141141
)
142142

143143
reranker = CrossEncoderReranker(config)
144-
result = reranker.rerank(query_result)
144+
result = await reranker.rerank(query_result)
145145

146146
# Result assertions
147147
assert isinstance(result, list)
148148
assert all(isinstance(path, str) for path in result)
149149
assert len(result) <= config.n_result
150150

151151

152-
def test_naive_reranker_document_selection_logic(naive_reranker_conf, query_result):
152+
@pytest.mark.asyncio
153+
async def test_naive_reranker_document_selection_logic(
154+
naive_reranker_conf, query_result
155+
):
153156
"""Test that NaiveReranker correctly selects documents based on distances"""
154157
# Create a query result with known distances
155158

156159
reranker = NaiveReranker(naive_reranker_conf)
157-
result = reranker.rerank(query_result)
160+
result = await reranker.rerank(query_result)
158161

159162
# Check that files are included (exact order depends on implementation details)
160163
assert len(result) > 0
161164
# Common files should be present
162165
assert "file2.py" in result or "file3.py" in result
163166

164167

165-
def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
168+
@pytest.mark.asyncio
169+
async def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
166170
"""Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
167171
naive_reranker_conf.include.append(
168172
QueryInclude.chunk
169173
) # Assuming QueryInclude.chunk would be "chunk"
170174

171175
reranker = NaiveReranker(naive_reranker_conf)
172-
result = reranker.rerank(query_result)
176+
result = await reranker.rerank(query_result)
173177

174178
assert isinstance(result, list)
175179
assert len(result) <= naive_reranker_conf.n_result
176180
assert all(isinstance(id, str) for id in result)
177181
assert all(id.startswith("id") for id in result) # Verify IDs not paths
178182

179183

184+
@pytest.mark.asyncio
180185
@patch("sentence_transformers.CrossEncoder")
181-
def test_cross_encoder_reranker_with_chunk_ids(
186+
async def test_cross_encoder_reranker_with_chunk_ids(
182187
mock_cross_encoder, config, query_result
183188
):
184189
"""Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
@@ -192,7 +197,7 @@ def test_cross_encoder_reranker_with_chunk_ids(
192197
config.include = {QueryInclude.chunk}
193198
reranker = CrossEncoderReranker(config)
194199

195-
result = reranker.rerank(query_result)
200+
result = await reranker.rerank(query_result)
196201

197202
mock_model.predict.assert_called()
198203
assert isinstance(result, list)
@@ -231,7 +236,7 @@ def test_add_reranker_success():
231236

232237
@add_reranker
233238
class TestReranker(RerankerBase):
234-
def compute_similarity(self, results, query_message):
239+
async def compute_similarity(self, results, query_message):
235240
return []
236241

237242
assert len(get_available_rerankers()) == original_count + 1
@@ -249,7 +254,7 @@ def test_add_reranker_duplicate():
249254
# First registration should succeed
250255
@add_reranker
251256
class TestReranker(RerankerBase):
252-
def rerank(self, results, query_chunks):
257+
async def compute_similarity(self, results, query_message):
253258
return []
254259

255260
# Second registration should fail

0 commit comments

Comments
 (0)