Skip to content

Commit 225b23a

Browse files
author
Zhe Yu
committed
add a decorator for registering custom rerankers for use as a library.
1 parent d165188 commit 225b23a

2 files changed

Lines changed: 103 additions & 3 deletions

File tree

src/vectorcode/subcommands/query/reranker/__init__.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import sys
3+
from typing import Type
34

45
from vectorcode.cli_utils import Config
56

@@ -11,11 +12,47 @@
1112

1213
logger = logging.getLogger(name=__name__)
1314

15+
__supported_rerankers: dict[str, Type[RerankerBase]] = {
16+
"CrossEncoderReranker": CrossEncoderReranker,
17+
"NaiveReranker": NaiveReranker,
18+
}
19+
20+
21+
def add_reranker(cls):
    """
    Register a custom reranker class so that `get_reranker` can look it up by name.

    This is a class decorator. Your reranker should inherit `RerankerBase` and be
    decorated by `add_reranker`:
    ```python
    @add_reranker
    class CustomReranker(RerankerBase):
        # override the methods according to your need.
    ```
    The class is stored in the registry under `cls.__name__`.

    Args:
        cls: The reranker class to register.

    Returns:
        The same class object, so the decorated name stays usable.

    Raises:
        TypeError: If `cls` is not a class derived from `RerankerBase`.
        AttributeError: If a reranker with the same class name is already
            registered.
    """
    # `issubclass` raises its own generic, unlogged TypeError when handed a
    # non-class, so check `isinstance(cls, type)` first and report every
    # invalid input through the same logged message.
    if not (isinstance(cls, type) and issubclass(cls, RerankerBase)):
        error_message = f'{cls} should be a subclass of "RerankerBase"'
        logger.error(error_message)
        raise TypeError(error_message)

    # Membership test rather than truthiness of the stored value.
    if cls.__name__ in __supported_rerankers:
        error_message = f"{cls.__name__} has been registered."
        logger.error(error_message)
        raise AttributeError(error_message)

    __supported_rerankers[cls.__name__] = cls
    return cls
44+
1445

1546
def get_reranker(configs: Config) -> RerankerBase:
16-
if configs.reranker and hasattr(sys.modules[__name__], configs.reranker):
17-
# dynamic dispatch
18-
return getattr(sys.modules[__name__], configs.reranker)(configs)
47+
if configs.reranker:
48+
if hasattr(sys.modules[__name__], configs.reranker):
49+
# dynamic dispatch for built-in rerankers
50+
return getattr(sys.modules[__name__], configs.reranker)(configs)
51+
52+
elif issubclass(
53+
__supported_rerankers.get(configs.reranker, type(None)), RerankerBase
54+
):
55+
return __supported_rerankers[configs.reranker](configs)
1956

2057
# TODO: replace the following with an Exception before the release of 0.6.0.
2158
logger.warning(

tests/subcommands/query/test_reranker.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,66 @@ def test_get_reranker():
244244
reranker.configs.reranker_params.get("model_name_or_path")
245245
== "cross-encoder/ms-marco-MiniLM-L-6-v2"
246246
), "configs.reranker_params should fallback to default params."
247+
248+
249+
def test_supported_rerankers_initialization():
    """Check that the registry holds the expected default rerankers and nothing else."""
    from vectorcode.subcommands.query.reranker import (
        __supported_rerankers as registry,
    )

    for default_name in ("CrossEncoderReranker", "NaiveReranker"):
        assert default_name in registry
    assert len(registry) == 2
256+
257+
258+
def test_add_reranker_success():
    """Test successful registration of a new reranker."""
    from vectorcode.subcommands.query.reranker import (
        RerankerBase,
        __supported_rerankers,
        add_reranker,
    )

    original_count = len(__supported_rerankers)

    try:

        @add_reranker
        class TestReranker(RerankerBase):
            def rerank(self, results, query_chunks):
                return []

        assert len(__supported_rerankers) == original_count + 1
        assert "TestReranker" in __supported_rerankers
        assert isinstance(get_reranker(Config(reranker="TestReranker")), TestReranker)
    finally:
        # The registry is a module-level dict shared by every test, so remove
        # the entry even when an assertion above fails. The default avoids a
        # KeyError masking the real failure if registration never happened.
        __supported_rerankers.pop("TestReranker", None)
277+
278+
279+
def test_add_reranker_duplicate():
    """Test duplicate reranker registration raises error."""
    from vectorcode.subcommands.query.reranker import (
        RerankerBase,
        __supported_rerankers,
        add_reranker,
    )

    try:
        # First registration should succeed
        @add_reranker
        class TestReranker(RerankerBase):
            def rerank(self, results, query_chunks):
                return []

        # Second registration should fail
        with pytest.raises(AttributeError):
            add_reranker(TestReranker)
    finally:
        # Clean up the shared module-level registry even if the expectation
        # above is not met, so the entry cannot leak into other tests.
        __supported_rerankers.pop("TestReranker", None)
297+
298+
299+
def test_add_reranker_invalid_baseclass():
    """Check that a class not derived from RerankerBase is rejected with TypeError."""
    from vectorcode.subcommands.query.reranker import add_reranker

    class InvalidReranker:
        pass

    # Applying the decorator as a plain call is equivalent to `@add_reranker`.
    with pytest.raises(TypeError):
        add_reranker(InvalidReranker)

0 commit comments

Comments
 (0)