Skip to content

Commit d032c0a

Browse files
author
Zhe Yu
committed
feat(cli): Stop hardcode the reranker of choice. Implements #68
1 parent 314cb01 commit d032c0a

10 files changed

Lines changed: 247 additions & 153 deletions

File tree

src/vectorcode/cli_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,12 @@ class Config:
8585
overlap_ratio: float = 0.2
8686
query_multiplier: int = -1
8787
query_exclude: list[PathLike] = field(default_factory=list)
88-
reranker: Optional[str] = "cross-encoder/ms-marco-MiniLM-L-6-v2"
89-
reranker_params: dict[str, Any] = field(default_factory=dict)
88+
reranker: Optional[str] = "CrossEncoderReranker"
89+
reranker_params: dict[str, Any] = field(
90+
default_factory=lambda: {
91+
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"
92+
}
93+
)
9094
check_item: Optional[str] = None
9195
use_absolute_path: bool = False
9296
include: list[QueryInclude] = field(
@@ -100,6 +104,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
100104
"""
101105
Raise IOError if db_path is not valid.
102106
"""
107+
default_config = Config()
103108
db_path = config_dict.get("db_path")
104109
host = config_dict.get("host") or "localhost"
105110
port = config_dict.get("port") or 8000
@@ -112,25 +117,35 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
112117
return Config(
113118
**{
114119
"embedding_function": config_dict.get(
115-
"embedding_function", "SentenceTransformerEmbeddingFunction"
120+
"embedding_function", default_config.embedding_function
121+
),
122+
"embedding_params": config_dict.get(
123+
"embedding_params", default_config.embedding_params
116124
),
117-
"embedding_params": config_dict.get("embedding_params", {}),
118125
"host": host,
119126
"port": port,
120127
"db_path": db_path,
121128
"db_log_path": os.path.expanduser(
122-
config_dict.get("db_log_path", "~/.local/share/vectorcode/")
129+
config_dict.get("db_log_path", default_config.db_log_path)
130+
),
131+
"chunk_size": config_dict.get("chunk_size", default_config.chunk_size),
132+
"overlap_ratio": config_dict.get(
133+
"overlap_ratio", default_config.overlap_ratio
134+
),
135+
"query_multiplier": config_dict.get(
136+
"query_multiplier", default_config.query_multiplier
137+
),
138+
"reranker": config_dict.get("reranker", default_config.reranker),
139+
"reranker_params": config_dict.get(
140+
"reranker_params", default_config.reranker_params
141+
),
142+
"db_settings": config_dict.get(
143+
"db_settings", default_config.db_settings
123144
),
124-
"chunk_size": config_dict.get("chunk_size", 2500),
125-
"overlap_ratio": config_dict.get("overlap_ratio", 0.2),
126-
"query_multiplier": config_dict.get("query_multiplier", -1),
127-
"reranker": config_dict.get(
128-
"reranker", "cross-encoder/ms-marco-MiniLM-L-6-v2"
145+
"hnsw": config_dict.get("hnsw", default_config.hnsw),
146+
"chunk_filters": config_dict.get(
147+
"chunk_filters", default_config.chunk_filters
129148
),
130-
"reranker_params": config_dict.get("reranker_params", {}),
131-
"db_settings": config_dict.get("db_settings", None),
132-
"hnsw": config_dict.get("hnsw", {}),
133-
"chunk_filters": config_dict.get("chunk_filters", {}),
134149
}
135150
)
136151

src/vectorcode/subcommands/query/__init__.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_collection,
1515
verify_ef,
1616
)
17+
from vectorcode.subcommands.query.reranker import get_reranker
1718

1819
logger = logging.getLogger(name=__name__)
1920

@@ -66,17 +67,8 @@ async def get_query_result_files(
6667
# no results found
6768
return []
6869

69-
if not configs.reranker:
70-
from .reranker import NaiveReranker
71-
72-
aggregated_results = NaiveReranker(configs).rerank(results)
73-
else:
74-
from .reranker import CrossEncoderReranker
75-
76-
aggregated_results = CrossEncoderReranker(
77-
configs, query_chunks, configs.reranker, **configs.reranker_params
78-
).rerank(results)
79-
return aggregated_results
70+
reranker = get_reranker(configs)
71+
return reranker.rerank(results, query_chunks)
8072

8173

8274
async def build_query_results(

src/vectorcode/subcommands/query/reranker.py

Lines changed: 0 additions & 95 deletions
This file was deleted.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import logging
2+
3+
from vectorcode.cli_utils import Config
4+
5+
from .base import RerankerBase
6+
from .cross_encoder import CrossEncoderReranker
7+
from .naive import NaiveReranker
8+
9+
__all__ = ["RerankerBase", "NaiveReranker", "CrossEncoderReranker"]
10+
11+
logger = logging.getLogger(name=__name__)
12+
13+
14+
def get_reranker(configs: Config) -> RerankerBase:
15+
if configs.reranker == "NaiveReranker" or not configs.reranker:
16+
return NaiveReranker(configs)
17+
elif configs.reranker == "CrossEncoderReranker":
18+
return CrossEncoderReranker(configs)
19+
else:
20+
logger.warning(
21+
f"""
22+
"reranker" option should be set to one of the following: {list(i for i in __all__ if i != "RerankerBase")}.
23+
To choose a custom reranker model, you can set the "model_name_or_path" key in the "reranker_params" option to the name/path of the model.
24+
The old configuration syntax will be deprecated in v0.6.0
25+
"""
26+
)
27+
configs.reranker_params.update({"model_name_or_path": configs.reranker})
28+
configs.reranker = "CrossEncoderReranker"
29+
return CrossEncoderReranker(
30+
configs,
31+
)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import logging
2+
from abc import abstractmethod
3+
from typing import Any
4+
5+
from chromadb.api.types import QueryResult
6+
7+
from vectorcode.cli_utils import Config
8+
9+
logger = logging.getLogger(name=__name__)
10+
11+
12+
class RerankerBase:
13+
def __init__(self, configs: Config, **kwargs: Any):
14+
self.configs = configs
15+
self.n_result = configs.n_result
16+
17+
@abstractmethod
18+
def rerank(self, results: QueryResult, query_chunks: list[str]) -> list[str]:
19+
raise NotImplementedError
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import heapq
2+
import logging
3+
from collections import defaultdict
4+
from typing import Any, DefaultDict
5+
6+
import numpy
7+
8+
from vectorcode.cli_utils import Config, QueryInclude
9+
10+
from .base import RerankerBase
11+
12+
logger = logging.getLogger(name=__name__)
13+
14+
15+
class CrossEncoderReranker(RerankerBase):
16+
def __init__(
17+
self,
18+
configs: Config,
19+
**kwargs: Any,
20+
):
21+
super().__init__(configs)
22+
from sentence_transformers import CrossEncoder
23+
24+
self.model = CrossEncoder(**configs.reranker_params)
25+
26+
def rerank(self, results, query_chunks) -> list[str]:
27+
self.query_chunks = query_chunks
28+
assert results["metadatas"] is not None
29+
assert results["documents"] is not None
30+
documents: DefaultDict[str, list[float]] = defaultdict(list)
31+
for query_chunk_idx in range(len(self.query_chunks)):
32+
chunk_ids = results["ids"][query_chunk_idx]
33+
chunk_metas = results["metadatas"][query_chunk_idx]
34+
chunk_docs = results["documents"][query_chunk_idx]
35+
ranks = self.model.rank(
36+
self.query_chunks[query_chunk_idx], chunk_docs, apply_softmax=True
37+
)
38+
for rank in ranks:
39+
if QueryInclude.chunk in self.configs.include:
40+
documents[chunk_ids[rank["corpus_id"]]].append(float(rank["score"]))
41+
else:
42+
documents[chunk_metas[rank["corpus_id"]]["path"]].append(
43+
float(rank["score"])
44+
)
45+
logger.debug("Document scores: %s", documents)
46+
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
47+
for key in documents.keys():
48+
documents[key] = heapq.nlargest(top_k, documents[key])
49+
50+
return heapq.nlargest(
51+
self.n_result,
52+
documents.keys(),
53+
key=lambda x: float(numpy.mean(documents[x])),
54+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import heapq
2+
import logging
3+
from collections import defaultdict
4+
from typing import Any, DefaultDict
5+
6+
import numpy
7+
8+
from vectorcode.cli_utils import Config, QueryInclude
9+
10+
from .base import RerankerBase
11+
12+
logger = logging.getLogger(name=__name__)
13+
14+
15+
class NaiveReranker(RerankerBase):
16+
def __init__(self, configs: Config, **kwargs: Any):
17+
super().__init__(configs)
18+
19+
def rerank(self, results, query_chunks) -> list[str]:
20+
assert results["metadatas"] is not None
21+
assert results["distances"] is not None
22+
documents: DefaultDict[str, list[float]] = defaultdict(list)
23+
for query_chunk_idx in range(len(results["ids"])):
24+
chunk_ids = results["ids"][query_chunk_idx]
25+
chunk_metas = results["metadatas"][query_chunk_idx]
26+
chunk_distances = results["distances"][query_chunk_idx]
27+
# NOTE: distances, smaller is better.
28+
paths = [str(meta["path"]) for meta in chunk_metas]
29+
assert len(paths) == len(chunk_distances)
30+
for distance, identifier in zip(
31+
chunk_distances,
32+
chunk_ids if QueryInclude.chunk in self.configs.include else paths,
33+
):
34+
if identifier is None: # pragma: nocover
35+
# so that vectorcode doesn't break on old collections.
36+
continue
37+
documents[identifier].append(distance)
38+
logger.debug("Document scores: %s", documents)
39+
top_k = int(numpy.mean(tuple(len(i) for i in documents.values())))
40+
for key in documents.keys():
41+
documents[key] = heapq.nsmallest(top_k, documents[key])
42+
43+
return heapq.nsmallest(
44+
self.n_result, documents.keys(), lambda x: float(numpy.mean(documents[x]))
45+
)

tests/subcommands/query/test_query.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def mock_config():
6262
@pytest.mark.asyncio
6363
async def test_get_query_result_files(mock_collection, mock_config):
6464
# Mock the reranker
65-
with patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker:
65+
with patch("vectorcode.subcommands.query.get_reranker") as mock_get_reranker:
6666
mock_reranker_instance = MagicMock()
6767
mock_reranker_instance.rerank.return_value = [
6868
"file1.py",
6969
"file2.py",
7070
"file3.py",
7171
]
72-
MockReranker.return_value = mock_reranker_instance
72+
mock_get_reranker.return_value = mock_reranker_instance
7373

7474
# Call the function
7575
result = await get_query_result_files(mock_collection, mock_config)
@@ -87,9 +87,9 @@ async def test_get_query_result_files(mock_collection, mock_config):
8787
assert not kwargs["where"] # Since query_exclude is empty
8888

8989
# Check reranker was used correctly
90-
MockReranker.assert_called_once_with(mock_config)
90+
mock_get_reranker.assert_called_once_with(mock_config)
9191
mock_reranker_instance.rerank.assert_called_once_with(
92-
mock_collection.query.return_value
92+
mock_collection.query.return_value, mock_config.query
9393
)
9494

9595
# Check the result

0 commit comments

Comments
 (0)