Skip to content

Commit a5dcc89

Browse files
author
Zhe Yu
committed
add get_available_rerankers function
1 parent 225b23a commit a5dcc89

2 files changed

Lines changed: 15 additions & 20 deletions

File tree

src/vectorcode/subcommands/query/reranker/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ class CustomReranker(RerankerBase):
4343
raise TypeError(error_message)
4444

4545

46+
def get_available_rerankers():
47+
return list(__supported_rerankers.values())
48+
49+
4650
def get_reranker(configs: Config) -> RerankerBase:
4751
if configs.reranker:
4852
if hasattr(sys.modules[__name__], configs.reranker):
@@ -56,7 +60,7 @@ def get_reranker(configs: Config) -> RerankerBase:
5660

5761
# TODO: replace the following with an Exception before the release of 0.6.0.
5862
logger.warning(
59-
f""""reranker" option should be set to one of the following: {list(i for i in __all__ if i != "RerankerBase")}.
63+
f""""reranker" option should be set to one of the following: {list(i.__name__ for i in get_available_rerankers())}.
6064
To choose a CrossEncoderReranker model, you can set the "model_name_or_path" key in the "reranker_params" option to the name/path of the model.
6165
To use NaiveReranker, set the "reranker" option to "NaiveReranker".
6266
The old configuration syntax will be DEPRECATED in v0.6.0

tests/subcommands/query/test_reranker.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
CrossEncoderReranker,
88
NaiveReranker,
99
RerankerBase,
10+
__supported_rerankers,
11+
add_reranker,
12+
get_available_rerankers,
1013
get_reranker,
1114
)
1215

@@ -248,41 +251,32 @@ def test_get_reranker():
248251

249252
def test_supported_rerankers_initialization():
250253
"""Test that __supported_rerankers contains the expected default rerankers"""
251-
from vectorcode.subcommands.query.reranker import __supported_rerankers
252254

253-
assert "CrossEncoderReranker" in __supported_rerankers
254-
assert "NaiveReranker" in __supported_rerankers
255-
assert len(__supported_rerankers) == 2
255+
assert isinstance(
256+
get_reranker(Config(reranker="CrossEncoderReranker")), CrossEncoderReranker
257+
)
258+
assert isinstance(get_reranker(Config(reranker="NaiveReranker")), NaiveReranker)
259+
assert len(get_available_rerankers()) == 2
256260

257261

258262
def test_add_reranker_success():
259263
"""Test successful registration of a new reranker"""
260-
from vectorcode.subcommands.query.reranker import (
261-
RerankerBase,
262-
__supported_rerankers,
263-
add_reranker,
264-
)
265264

266-
original_count = len(__supported_rerankers)
265+
original_count = len(get_available_rerankers())
267266

268267
@add_reranker
269268
class TestReranker(RerankerBase):
270269
def rerank(self, results, query_chunks):
271270
return []
272271

273-
assert len(__supported_rerankers) == original_count + 1
272+
assert len(get_available_rerankers()) == original_count + 1
274273
assert "TestReranker" in __supported_rerankers
275274
assert isinstance(get_reranker(Config(reranker="TestReranker")), TestReranker)
276275
__supported_rerankers.pop("TestReranker")
277276

278277

279278
def test_add_reranker_duplicate():
280279
"""Test duplicate reranker registration raises error"""
281-
from vectorcode.subcommands.query.reranker import (
282-
RerankerBase,
283-
__supported_rerankers,
284-
add_reranker,
285-
)
286280

287281
# First registration should succeed
288282
@add_reranker
@@ -298,9 +292,6 @@ def rerank(self, results, query_chunks):
298292

299293
def test_add_reranker_invalid_baseclass():
300294
"""Test that non-RerankerBase classes can't be registered"""
301-
from vectorcode.subcommands.query.reranker import (
302-
add_reranker,
303-
)
304295

305296
with pytest.raises(TypeError):
306297

0 commit comments

Comments
 (0)