@@ -81,10 +81,11 @@ def __init__(self, configs, **kwargs):
8181 TestReranker .create (Config ())
8282
8383
84- def test_naive_reranker_rerank (naive_reranker_conf , query_result ):
84+ @pytest .mark .asyncio
85+ async def test_naive_reranker_rerank (naive_reranker_conf , query_result ):
8586 """Test basic reranking functionality of NaiveReranker"""
8687 reranker = NaiveReranker (naive_reranker_conf )
87- result = reranker .rerank (query_result )
88+ result = await reranker .rerank (query_result )
8889
8990 # Check the result is a list of paths with correct length
9091 assert isinstance (result , list )
@@ -117,10 +118,9 @@ def test_cross_encoder_reranker_initialization_fallback_model_name(
117118 assert reranker .n_result == config .n_result
118119
119120
121+ @pytest .mark .asyncio
120122@patch ("sentence_transformers.CrossEncoder" )
121- def test_cross_encoder_reranker_rerank (
122- mock_cross_encoder , config , query_result , query_chunks
123- ):
123+ async def test_cross_encoder_reranker_rerank (mock_cross_encoder , config , query_result ):
124124 mock_model = MagicMock ()
125125 mock_cross_encoder .return_value = mock_model
126126
@@ -141,44 +141,49 @@ def test_cross_encoder_reranker_rerank(
141141 )
142142
143143 reranker = CrossEncoderReranker (config )
144- result = reranker .rerank (query_result )
144+ result = await reranker .rerank (query_result )
145145
146146 # Result assertions
147147 assert isinstance (result , list )
148148 assert all (isinstance (path , str ) for path in result )
149149 assert len (result ) <= config .n_result
150150
151151
152- def test_naive_reranker_document_selection_logic (naive_reranker_conf , query_result ):
152+ @pytest .mark .asyncio
153+ async def test_naive_reranker_document_selection_logic (
154+ naive_reranker_conf , query_result
155+ ):
153156 """Test that NaiveReranker correctly selects documents based on distances"""
154157 # Create a query result with known distances
155158
156159 reranker = NaiveReranker (naive_reranker_conf )
157- result = reranker .rerank (query_result )
160+ result = await reranker .rerank (query_result )
158161
159162 # Check that files are included (exact order depends on implementation details)
160163 assert len (result ) > 0
161164 # Common files should be present
162165 assert "file2.py" in result or "file3.py" in result
163166
164167
165- def test_naive_reranker_with_chunk_ids (naive_reranker_conf , query_result ):
168+ @pytest .mark .asyncio
169+ async def test_naive_reranker_with_chunk_ids (naive_reranker_conf , query_result ):
166170 """Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
167171 naive_reranker_conf .include .append (
168172 QueryInclude .chunk
169173 ) # Assuming QueryInclude.chunk would be "chunk"
170174
171175 reranker = NaiveReranker (naive_reranker_conf )
172- result = reranker .rerank (query_result )
176+ result = await reranker .rerank (query_result )
173177
174178 assert isinstance (result , list )
175179 assert len (result ) <= naive_reranker_conf .n_result
176180 assert all (isinstance (id , str ) for id in result )
177181 assert all (id .startswith ("id" ) for id in result ) # Verify IDs not paths
178182
179183
184+ @pytest .mark .asyncio
180185@patch ("sentence_transformers.CrossEncoder" )
181- def test_cross_encoder_reranker_with_chunk_ids (
186+ async def test_cross_encoder_reranker_with_chunk_ids (
182187 mock_cross_encoder , config , query_result
183188):
184189 """Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
@@ -192,7 +197,7 @@ def test_cross_encoder_reranker_with_chunk_ids(
192197 config .include = {QueryInclude .chunk }
193198 reranker = CrossEncoderReranker (config )
194199
195- result = reranker .rerank (query_result )
200+ result = await reranker .rerank (query_result )
196201
197202 mock_model .predict .assert_called ()
198203 assert isinstance (result , list )
@@ -231,7 +236,7 @@ def test_add_reranker_success():
231236
232237 @add_reranker
233238 class TestReranker (RerankerBase ):
234- def compute_similarity (self , results , query_message ):
239+ async def compute_similarity (self , results , query_message ):
235240 return []
236241
237242 assert len (get_available_rerankers ()) == original_count + 1
@@ -249,7 +254,7 @@ def test_add_reranker_duplicate():
249254 # First registration should succeed
250255 @add_reranker
251256 class TestReranker (RerankerBase ):
252- def rerank (self , results , query_chunks ):
257+ async def compute_similarity (self , results , query_message ):
253258 return []
254259
255260 # Second registration should fail
0 commit comments