Skip to content

Commit 4a8d9bd

Browse files
committed
RerankerBase.rerank doesn't need the query messages as a parameter.
1 parent 871c3fc commit 4a8d9bd

6 files changed

Lines changed: 32 additions & 24 deletions

File tree

src/vectorcode/subcommands/query/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def get_query_result_files(
6868
return []
6969

7070
reranker = get_reranker(configs)
71-
return reranker.rerank(results, query_chunks)
71+
return reranker.rerank(results)
7272

7373

7474
async def build_query_results(

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,5 @@ def create(cls, configs: Config, **kwargs: Any):
3838
raise
3939

4040
@abstractmethod
41-
def rerank(
42-
self, results: QueryResult, query_chunks: list[str]
43-
) -> list[str]: # pragma: nocover
41+
def rerank(self, results: QueryResult) -> list[str]: # pragma: nocover
4442
raise NotImplementedError

src/vectorcode/subcommands/query/reranker/cross_encoder.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ def __init__(
2525
**kwargs: Any,
2626
):
2727
super().__init__(configs)
28+
assert self.configs.query is not None, (
29+
"'configs' should contain the query messages."
30+
)
2831
from sentence_transformers import CrossEncoder
2932

3033
if configs.reranker_params.get("model_name_or_path") is None:
@@ -36,17 +39,18 @@ def __init__(
3639
)
3740
self.model = CrossEncoder(**configs.reranker_params)
3841

39-
def rerank(self, results, query_chunks) -> list[str]:
40-
self.query_chunks = query_chunks
42+
def rerank(self, results) -> list[str]:
43+
assert self.configs.query
44+
query_chunks = self.configs.query
4145
assert results["metadatas"] is not None
4246
assert results["documents"] is not None
4347
documents: DefaultDict[str, list[float]] = defaultdict(list)
44-
for query_chunk_idx in range(len(self.query_chunks)):
48+
for query_chunk_idx in range(len(query_chunks)):
4549
chunk_ids = results["ids"][query_chunk_idx]
4650
chunk_metas = results["metadatas"][query_chunk_idx]
4751
chunk_docs = results["documents"][query_chunk_idx]
4852
ranks = self.model.rank(
49-
self.query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
53+
query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
5054
)
5155
for rank in ranks:
5256
if QueryInclude.chunk in self.configs.include:

src/vectorcode/subcommands/query/reranker/naive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class NaiveReranker(RerankerBase):
2121
def __init__(self, configs: Config, **kwargs: Any):
2222
super().__init__(configs)
2323

24-
def rerank(self, results, query_chunks) -> list[str]:
24+
def rerank(self, results) -> list[str]:
2525
assert results["metadatas"] is not None
2626
assert results["distances"] is not None
2727
documents: DefaultDict[str, list[float]] = defaultdict(list)

tests/subcommands/query/test_query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def test_get_query_result_files(mock_collection, mock_config):
8989
# Check reranker was used correctly
9090
mock_get_reranker.assert_called_once_with(mock_config)
9191
mock_reranker_instance.rerank.assert_called_once_with(
92-
mock_collection.query.return_value, mock_config.query
92+
mock_collection.query.return_value
9393
)
9494

9595
# Check the result

tests/subcommands/query/test_reranker.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(self, configs, **kwargs):
8383
def test_naive_reranker_rerank(naive_reranker_conf, query_result):
8484
"""Test basic reranking functionality of NaiveReranker"""
8585
reranker = NaiveReranker(naive_reranker_conf)
86-
result = reranker.rerank(query_result, ["foo", "bar"])
86+
result = reranker.rerank(query_result)
8787

8888
# Check the result is a list of paths with correct length
8989
assert isinstance(result, list)
@@ -104,7 +104,7 @@ def test_naive_reranker_handles_none_path(config, query_result):
104104
]
105105

106106
reranker = NaiveReranker(config)
107-
result = reranker.rerank(query_result_with_none, ["foo", "bar"])
107+
result = reranker.rerank(query_result_with_none)
108108

109109
# Check the None path was handled without errors
110110
assert isinstance(result, list)
@@ -121,6 +121,19 @@ def test_cross_encoder_reranker_initialization(mock_cross_encoder: MagicMock, co
121121
assert reranker.n_result == config.n_result
122122

123123

124+
@patch("sentence_transformers.CrossEncoder")
125+
def test_cross_encoder_reranker_initialization_fallback_model_name(
126+
mock_cross_encoder: MagicMock, config
127+
):
128+
expected_params = {"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}
129+
config.reranker_params = {}
130+
reranker = CrossEncoderReranker(config)
131+
132+
# Verify constructor was called with correct parameters
133+
mock_cross_encoder.assert_called_once_with(**expected_params)
134+
assert reranker.n_result == config.n_result
135+
136+
124137
@patch("sentence_transformers.CrossEncoder")
125138
def test_cross_encoder_reranker_rerank(
126139
mock_cross_encoder, config, query_result, query_chunks
@@ -137,7 +150,7 @@ def test_cross_encoder_reranker_rerank(
137150

138151
reranker = CrossEncoderReranker(config)
139152

140-
result = reranker.rerank(query_result, query_chunks)
153+
result = reranker.rerank(query_result)
141154

142155
# Verify the model was called with correct parameters
143156
mock_model.rank.assert_called()
@@ -167,7 +180,7 @@ def test_naive_reranker_document_selection_logic(naive_reranker_conf):
167180
}
168181

169182
reranker = NaiveReranker(naive_reranker_conf)
170-
result = reranker.rerank(query_result, naive_reranker_conf.query)
183+
result = reranker.rerank(query_result)
171184

172185
# Check that files are included (exact order depends on implementation details)
173186
assert len(result) > 0
@@ -189,7 +202,7 @@ def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
189202
],
190203
}
191204
reranker = NaiveReranker(naive_reranker_conf)
192-
result = reranker.rerank(query_result, naive_reranker_conf.query)
205+
result = reranker.rerank(query_result)
193206

194207
assert isinstance(result, list)
195208
assert len(result) <= naive_reranker_conf.n_result
@@ -224,26 +237,19 @@ def test_cross_encoder_reranker_with_chunk_ids(
224237
],
225238
"documents": [["doc1", "doc2"], ["doc3", "doc4"]],
226239
},
227-
config.query,
228240
)
229241

230242
assert isinstance(result, list)
231243
assert all(isinstance(id, str) for id in result)
232244
assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
233245

234246

235-
def test_get_reranker():
236-
config = Config(reranker="NaiveReranker")
237-
assert get_reranker(config).configs.reranker == "NaiveReranker"
247+
def test_get_reranker(config, naive_reranker_conf):
248+
assert get_reranker(naive_reranker_conf).configs.reranker == "NaiveReranker"
238249

239-
config = Config(reranker="CrossEncoderReranker", reranker_params={"device": "cpu"})
240250
reranker = get_reranker(config)
241251
assert reranker.configs.reranker == "CrossEncoderReranker"
242252

243-
config = Config(
244-
reranker="cross-encoder/ms-marco-MiniLM-L-6-v2",
245-
reranker_params={"device": "cpu"},
246-
)
247253
reranker = get_reranker(config)
248254
assert reranker.configs.reranker == "CrossEncoderReranker", (
249255
"configs.reranker should fallback to 'CrossEncoderReranker'"

0 commit comments

Comments
 (0)