Skip to content

Commit 20cdaa9

Browse files
authored
feat(cli): extensible reranker. Implements #68 (#85)
* feat(cli): Stop hardcode the reranker of choice. Implements #68 * improve fallback on faulty configs. * update documentation for reranker. * dynamically dispatch reranker class * add a decorator for registering custom rerankers for use as a library. * add `get_available_rerankers` function * add reranker class __doc__ to the error message when something's wrong * fix broken tests (for github actions) * some cleanup * `RerankerBase.rerank` doesn't need the query messages as a parameter. * simplify the process of building a new reranker by creating a `compute_similarity` method * make `rerank` and `compute_similarity` async.
1 parent 314cb01 commit 20cdaa9

11 files changed

Lines changed: 473 additions & 252 deletions

File tree

docs/cli.md

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,14 +248,25 @@ The JSON configuration file may hold the following values:
248248
guarantees the return of `n` documents, but with the risk of including too
249249
many less-relevant chunks that may affect the document selection. Default:
250250
`-1` (any negative value means selecting documents based on all indexed chunks);
251-
- `reranker`: string, a reranking model supported by
252-
[`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/index.html).
253-
A list of available models is available on their documentation. The default
254-
model is `"cross-encoder/ms-marco-MiniLM-L-6-v2"`. You can disable the use of
255-
`CrossEncoder` by setting this option to a falsy value that is not `null`,
256-
such as `false` or `""` (empty string);
251+
- `reranker`: string, the reranking method to use. Currently supports
252+
`CrossEncoderReranker` (default, using
253+
[sentence-transformers cross-encoder](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html)
254+
) and `NaiveReranker` (sort chunks by the "distance" between the embedding
255+
vectors);
257256
- `reranker_params`: dictionary, similar to `embedding_params`. The options
258-
passed to `CrossEncoder` class constructor;
257+
passed to the reranker class constructor. For `CrossEncoderReranker`, these
258+
are the options passed to the
259+
[`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html#id1)
260+
class. For example, if you want to use a non-default model, you can use the
261+
following:
262+
```json
263+
{
264+
"reranker_params": {
265+
"model_name_or_path": "your_model_here"
266+
}
267+
}
268+
```
269+
;
259270
- `db_settings`: dictionary, works in a similar way to `embedding_params`, but
260271
for Chromadb client settings so that you can configure
261272
[authentication for remote Chromadb](https://docs.trychroma.com/production/administration/auth);

src/vectorcode/cli_utils.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ class Config:
8585
overlap_ratio: float = 0.2
8686
query_multiplier: int = -1
8787
query_exclude: list[PathLike] = field(default_factory=list)
88-
reranker: Optional[str] = "cross-encoder/ms-marco-MiniLM-L-6-v2"
89-
reranker_params: dict[str, Any] = field(default_factory=dict)
88+
reranker: Optional[str] = "CrossEncoderReranker"
89+
reranker_params: dict[str, Any] = field(default_factory=lambda: {})
9090
check_item: Optional[str] = None
9191
use_absolute_path: bool = False
9292
include: list[QueryInclude] = field(
@@ -100,6 +100,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
100100
"""
101101
Raise IOError if db_path is not valid.
102102
"""
103+
default_config = Config()
103104
db_path = config_dict.get("db_path")
104105
host = config_dict.get("host") or "localhost"
105106
port = config_dict.get("port") or 8000
@@ -112,25 +113,35 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
112113
return Config(
113114
**{
114115
"embedding_function": config_dict.get(
115-
"embedding_function", "SentenceTransformerEmbeddingFunction"
116+
"embedding_function", default_config.embedding_function
117+
),
118+
"embedding_params": config_dict.get(
119+
"embedding_params", default_config.embedding_params
116120
),
117-
"embedding_params": config_dict.get("embedding_params", {}),
118121
"host": host,
119122
"port": port,
120123
"db_path": db_path,
121124
"db_log_path": os.path.expanduser(
122-
config_dict.get("db_log_path", "~/.local/share/vectorcode/")
125+
config_dict.get("db_log_path", default_config.db_log_path)
126+
),
127+
"chunk_size": config_dict.get("chunk_size", default_config.chunk_size),
128+
"overlap_ratio": config_dict.get(
129+
"overlap_ratio", default_config.overlap_ratio
130+
),
131+
"query_multiplier": config_dict.get(
132+
"query_multiplier", default_config.query_multiplier
133+
),
134+
"reranker": config_dict.get("reranker", default_config.reranker),
135+
"reranker_params": config_dict.get(
136+
"reranker_params", default_config.reranker_params
137+
),
138+
"db_settings": config_dict.get(
139+
"db_settings", default_config.db_settings
123140
),
124-
"chunk_size": config_dict.get("chunk_size", 2500),
125-
"overlap_ratio": config_dict.get("overlap_ratio", 0.2),
126-
"query_multiplier": config_dict.get("query_multiplier", -1),
127-
"reranker": config_dict.get(
128-
"reranker", "cross-encoder/ms-marco-MiniLM-L-6-v2"
141+
"hnsw": config_dict.get("hnsw", default_config.hnsw),
142+
"chunk_filters": config_dict.get(
143+
"chunk_filters", default_config.chunk_filters
129144
),
130-
"reranker_params": config_dict.get("reranker_params", {}),
131-
"db_settings": config_dict.get("db_settings", None),
132-
"hnsw": config_dict.get("hnsw", {}),
133-
"chunk_filters": config_dict.get("chunk_filters", {}),
134145
}
135146
)
136147

src/vectorcode/subcommands/query/__init__.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_collection,
1515
verify_ef,
1616
)
17+
from vectorcode.subcommands.query.reranker import get_reranker
1718

1819
logger = logging.getLogger(name=__name__)
1920

@@ -66,17 +67,8 @@ async def get_query_result_files(
6667
# no results found
6768
return []
6869

69-
if not configs.reranker:
70-
from .reranker import NaiveReranker
71-
72-
aggregated_results = NaiveReranker(configs).rerank(results)
73-
else:
74-
from .reranker import CrossEncoderReranker
75-
76-
aggregated_results = CrossEncoderReranker(
77-
configs, query_chunks, configs.reranker, **configs.reranker_params
78-
).rerank(results)
79-
return aggregated_results
70+
reranker = get_reranker(configs)
71+
return await reranker.rerank(results)
8072

8173

8274
async def build_query_results(

src/vectorcode/subcommands/query/reranker.py

Lines changed: 0 additions & 95 deletions
This file was deleted.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import logging
2+
import sys
3+
from typing import Type
4+
5+
from vectorcode.cli_utils import Config
6+
7+
from .base import RerankerBase
8+
from .cross_encoder import CrossEncoderReranker
9+
from .naive import NaiveReranker
10+
11+
__all__ = ["RerankerBase", "NaiveReranker", "CrossEncoderReranker"]
12+
13+
logger = logging.getLogger(name=__name__)
14+
15+
__supported_rerankers: dict[str, Type[RerankerBase]] = {
16+
"CrossEncoderReranker": CrossEncoderReranker,
17+
"NaiveReranker": NaiveReranker,
18+
}
19+
20+
21+
def add_reranker(cls):
22+
"""
23+
This is a class decorator that allows you to add a custom reranker that can be
24+
recognised by the `get_reranker` function.
25+
26+
Your reranker should inherit `RerankerBase` and be decorated by `add_reranker`:
27+
```python
28+
@add_reranker
29+
class CustomReranker(RerankerBase):
30+
# override the methods according to your need.
31+
```
32+
"""
33+
if issubclass(cls, RerankerBase):
34+
if __supported_rerankers.get(cls.__name__):
35+
error_message = f"{cls.__name__} has been registered."
36+
logger.error(error_message)
37+
raise AttributeError(error_message)
38+
__supported_rerankers[cls.__name__] = cls
39+
return cls
40+
else:
41+
error_message = f'{cls} should be a subclass of "RerankerBase"'
42+
logger.error(error_message)
43+
raise TypeError(error_message)
44+
45+
46+
def get_available_rerankers():
47+
return list(__supported_rerankers.values())
48+
49+
50+
def get_reranker(configs: Config) -> RerankerBase:
51+
if configs.reranker:
52+
if hasattr(sys.modules[__name__], configs.reranker):
53+
# dynamic dispatch for built-in rerankers
54+
return getattr(sys.modules[__name__], configs.reranker).create(configs)
55+
56+
elif issubclass(
57+
__supported_rerankers.get(configs.reranker, type(None)), RerankerBase
58+
):
59+
return __supported_rerankers[configs.reranker].create(configs)
60+
61+
# TODO: replace the following with an Exception before the release of 0.6.0.
62+
logger.warning(
63+
f""""reranker" option should be set to one of the following: {list(i.__name__ for i in get_available_rerankers())}.
64+
To choose a CrossEncoderReranker model, you can set the "model_name_or_path" key in the "reranker_params" option to the name/path of the model.
65+
To use NaiveReranker, set the "reranker" option to "NaiveReranker".
66+
The old configuration syntax will be DEPRECATED in v0.6.0
67+
"""
68+
)
69+
if not configs.reranker:
70+
return NaiveReranker(configs)
71+
else:
72+
configs.reranker_params.update({"model_name_or_path": configs.reranker})
73+
configs.reranker = "CrossEncoderReranker"
74+
return CrossEncoderReranker(
75+
configs,
76+
)
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import heapq
2+
import logging
3+
from abc import ABC, abstractmethod
4+
from collections import defaultdict
5+
from typing import Any, DefaultDict, Optional, Sequence, cast
6+
7+
import numpy
8+
from chromadb.api.types import QueryResult
9+
10+
from vectorcode.cli_utils import Config, QueryInclude
11+
12+
logger = logging.getLogger(name=__name__)
13+
14+
15+
class RerankerBase(ABC):
16+
"""This is the base class for the rerankers.
17+
You should use the configs.reranker_params field to store and pass the parameters used for your reranker.
18+
You should implement the `compute_similarity` method, which will be called by `rerank` to compute
19+
similarity scores between search query and results.
20+
The items in the returned list should be sorted such that the relevance decreases along the list.
21+
22+
The class doc string will be added to the error message if your reranker fails to initialise.
23+
Thus, this is a good place to put the instructions to configuring your reranker.
24+
"""
25+
26+
def __init__(self, configs: Config, **kwargs: Any):
27+
self.configs = configs
28+
assert self.configs.query is not None, (
29+
"'configs' should contain the query messages."
30+
)
31+
self.n_result = configs.n_result
32+
self._raw_results: Optional[QueryResult] = None
33+
34+
@classmethod
35+
def create(cls, configs: Config, **kwargs: Any):
36+
try:
37+
return cls(configs, **kwargs)
38+
except Exception as e:
39+
e.add_note(
40+
"\n"
41+
+ (
42+
cls.__doc__
43+
or f"There was an issue initialising {cls}. Please doublecheck your configuration."
44+
)
45+
)
46+
raise
47+
48+
@abstractmethod
49+
async def compute_similarity(
50+
self, results: list[str], query_message: str
51+
) -> Sequence[float]: # pragma: nocover
52+
"""Given a list of n results and 1 query message,
53+
return a list-like object of length n that contains the similarity scores between
54+
each item in `results` and the `query_message`.
55+
56+
A high similarity score means the strings are semantically similar to each other.
57+
`query_message` will be loaded in the same order as they appear in `self.configs.query`.
58+
59+
If you need the raw query results from chromadb,
60+
it'll be saved in `self._raw_results` before this method is called.
61+
"""
62+
raise NotImplementedError
63+
64+
async def rerank(self, results: QueryResult | dict) -> list[str]:
65+
self._raw_results = cast(QueryResult, results)
66+
query_chunks = self.configs.query
67+
assert query_chunks
68+
assert results["metadatas"] is not None
69+
assert results["documents"] is not None
70+
documents: DefaultDict[str, list[float]] = defaultdict(list)
71+
for query_chunk_idx in range(len(query_chunks)):
72+
chunk_ids = results["ids"][query_chunk_idx]
73+
chunk_metas = results["metadatas"][query_chunk_idx]
74+
chunk_docs = results["documents"][query_chunk_idx]
75+
scores = await self.compute_similarity(
76+
chunk_docs, query_chunks[query_chunk_idx]
77+
)
78+
for i, score in enumerate(scores):
79+
if QueryInclude.chunk in self.configs.include:
80+
documents[chunk_ids[i]].append(float(score))
81+
else:
82+
documents[str(chunk_metas[i]["path"])].append(float(score))
83+
84+
logger.debug("Document scores: %s", documents)
85+
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
86+
for key in documents.keys():
87+
documents[key] = heapq.nlargest(top_k, documents[key])
88+
89+
self._raw_results = None
90+
91+
return heapq.nlargest(
92+
self.n_result,
93+
documents.keys(),
94+
key=lambda x: float(numpy.mean(documents[x])),
95+
)

0 commit comments

Comments
 (0)