11from unittest .mock import MagicMock , patch
22
3+ import numpy
34import pytest
45
56from vectorcode .cli_utils import Config , QueryInclude
@@ -94,24 +95,6 @@ def test_naive_reranker_rerank(naive_reranker_conf, query_result):
9495 assert isinstance (path , str )
9596
9697
97- def test_naive_reranker_handles_none_path (config , query_result ):
98- """Test NaiveReranker properly handles None paths in metadata"""
99- # Create a copy with a None path
100- query_result_with_none = query_result .copy ()
101- query_result_with_none ["metadatas" ] = [
102- [{"path" : "file1.py" }, {"path" : None }, {"path" : "file3.py" }],
103- [{"path" : "file2.py" }, {"path" : "file4.py" }, {"path" : "file3.py" }],
104- ]
105-
106- reranker = NaiveReranker (config )
107- result = reranker .rerank (query_result_with_none )
108-
109- # Check the None path was handled without errors
110- assert isinstance (result , list )
111- # None should be filtered out
112- assert None not in result
113-
114-
11598@patch ("sentence_transformers.CrossEncoder" )
11699def test_cross_encoder_reranker_initialization (mock_cross_encoder : MagicMock , config ):
117100 reranker = CrossEncoderReranker (config )
@@ -141,43 +124,34 @@ def test_cross_encoder_reranker_rerank(
141124 mock_model = MagicMock ()
142125 mock_cross_encoder .return_value = mock_model
143126
144- # Configure mock rank method to return predetermined ranks
145- mock_model .rank .return_value = [
146- {"corpus_id" : 0 , "score" : 0.9 },
147- {"corpus_id" : 1 , "score" : 0.7 },
148- {"corpus_id" : 2 , "score" : 0.8 },
149- ]
127+ # Configure mock predict to return numpy array with float32 dtype
128+ scores = numpy .array ([0.9 , 0.7 , 0.8 ], dtype = numpy .float32 )
129+ mock_model .predict .return_value = scores
150130
151- reranker = CrossEncoderReranker (config )
131+ # Ensure complete query_result structure
132+ query_result .update (
133+ {
134+ "ids" : [["id1" , "id2" , "id3" ], ["id4" , "id5" , "id6" ]],
135+ "documents" : [["doc1" , "doc2" , "doc3" ], ["doc4" , "doc5" , "doc6" ]],
136+ "metadatas" : [
137+ [{"path" : "p1" }, {"path" : "p2" }, {"path" : "p3" }],
138+ [{"path" : "p4" }, {"path" : "p5" }, {"path" : "p6" }],
139+ ],
140+ }
141+ )
152142
143+ reranker = CrossEncoderReranker (config )
153144 result = reranker .rerank (query_result )
154145
155- # Verify the model was called with correct parameters
156- mock_model .rank .assert_called ()
157-
158- # Check result
146+ # Result assertions
159147 assert isinstance (result , list )
148+ assert all (isinstance (path , str ) for path in result )
160149 assert len (result ) <= config .n_result
161150
162- # Check all returned items are strings (paths)
163- for path in result :
164- assert isinstance (path , str )
165-
166151
167- def test_naive_reranker_document_selection_logic (naive_reranker_conf ):
152+ def test_naive_reranker_document_selection_logic (naive_reranker_conf , query_result ):
168153 """Test that NaiveReranker correctly selects documents based on distances"""
169154 # Create a query result with known distances
170- query_result = {
171- "ids" : [["id1" , "id2" , "id3" ], ["id4" , "id5" , "id6" ]],
172- "distances" : [
173- [0.3 , 0.1 , 0.2 ], # file2 has lowest, then file3, then file1
174- [0.6 , 0.4 , 0.5 ], # file4 has lowest, then file3, then file2
175- ],
176- "metadatas" : [
177- [{"path" : "file1.py" }, {"path" : "file2.py" }, {"path" : "file3.py" }],
178- [{"path" : "file2.py" }, {"path" : "file4.py" }, {"path" : "file3.py" }],
179- ],
180- }
181155
182156 reranker = NaiveReranker (naive_reranker_conf )
183157 result = reranker .rerank (query_result )
@@ -188,19 +162,12 @@ def test_naive_reranker_document_selection_logic(naive_reranker_conf):
188162 assert "file2.py" in result or "file3.py" in result
189163
190164
191- def test_naive_reranker_with_chunk_ids (naive_reranker_conf ):
165+ def test_naive_reranker_with_chunk_ids (naive_reranker_conf , query_result ):
192166 """Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
193167 naive_reranker_conf .include .append (
194168 QueryInclude .chunk
195169 ) # Assuming QueryInclude.chunk would be "chunk"
196- query_result = {
197- "ids" : [["id1" , "id2" ], ["id3" , "id1" ]],
198- "distances" : [[0.1 , 0.2 ], [0.3 , 0.4 ]],
199- "metadatas" : [
200- [{"path" : "file1.py" }, {"path" : "file2.py" }],
201- [{"path" : "file3.py" }, {"path" : "file1.py" }],
202- ],
203- }
170+
204171 reranker = NaiveReranker (naive_reranker_conf )
205172 result = reranker .rerank (query_result )
206173
@@ -212,33 +179,22 @@ def test_naive_reranker_with_chunk_ids(naive_reranker_conf):
212179
213180@patch ("sentence_transformers.CrossEncoder" )
214181def test_cross_encoder_reranker_with_chunk_ids (
215- mock_cross_encoder , config , query_chunks
182+ mock_cross_encoder , config , query_result
216183):
217184 """Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
218185 mock_model = MagicMock ()
219186 mock_cross_encoder .return_value = mock_model
220- mock_model .rank .return_value = [
221- {"corpus_id" : 0 , "score" : 0.9 },
222- {"corpus_id" : 1 , "score" : 0.7 },
223- ]
224-
225- config .include = {"chunk" } # Use comma instead of append
226- reranker = CrossEncoderReranker (
227- config ,
228- )
229187
230- # Match query_chunks length with results
231- result = reranker .rerank (
232- {
233- "ids" : [["id1" , "id2" ], ["id3" , "id4" ]], # Two query chunks
234- "metadatas" : [
235- [{"path" : "file1.py" }, {"path" : "file2.py" }],
236- [{"path" : "file3.py" }, {"path" : "file4.py" }],
237- ],
238- "documents" : [["doc1" , "doc2" ], ["doc3" , "doc4" ]],
239- },
240- )
188+ # Setup mock to return numpy array scores
189+ scores = numpy .array ([0.9 , 0.7 ], dtype = numpy .float32 )
190+ mock_model .predict .return_value = scores
191+
192+ config .include = {QueryInclude .chunk }
193+ reranker = CrossEncoderReranker (config )
241194
195+ result = reranker .rerank (query_result )
196+
197+ mock_model .predict .assert_called ()
242198 assert isinstance (result , list )
243199 assert all (isinstance (id , str ) for id in result )
244200 assert all (id in ["id1" , "id2" , "id3" , "id4" ] for id in result )
@@ -275,12 +231,15 @@ def test_add_reranker_success():
275231
276232 @add_reranker
277233 class TestReranker (RerankerBase ):
278- def rerank (self , results , query_chunks ):
234+ def compute_similarity (self , results , query_message ):
279235 return []
280236
281237 assert len (get_available_rerankers ()) == original_count + 1
282238 assert "TestReranker" in __supported_rerankers
283- assert isinstance (get_reranker (Config (reranker = "TestReranker" )), TestReranker )
239+ assert isinstance (
240+ get_reranker (Config (reranker = "TestReranker" , query = ["hello world" ])),
241+ TestReranker ,
242+ )
284243 __supported_rerankers .pop ("TestReranker" )
285244
286245
0 commit comments