Skip to content

Commit 0d6e338

Browse files
author
Zhe Yu
committed
rename parameter
1 parent 27f96ef commit 0d6e338

3 files changed

Lines changed: 22 additions & 5 deletions

File tree

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def rerank(self, results: list[QueryResult]) -> list[str]:
6262
group_by = "path"
6363
if QueryInclude.chunk in self.configs.include:
6464
group_by = "chunk"
65-
grouped_results = QueryResult.group(*results, key=group_by, top_k="auto")
65+
grouped_results = QueryResult.group(*results, by=group_by, top_k="auto")
6666

6767
scores: dict[Chunk | str, float] = {}
6868
for key in grouped_results.keys():

src/vectorcode/subcommands/query/types.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,25 @@ def merge(cls, *results: "QueryResult") -> "QueryResult":
4646
@staticmethod
4747
def group(
4848
*results: "QueryResult",
49-
key: Union[Literal["path"], Literal["chunk"]] = "path",
49+
by: Union[Literal["path"], Literal["chunk"]] = "path",
5050
top_k: int | Literal["auto"] | None = None,
5151
) -> dict[Chunk | str, list["QueryResult"]]:
52-
assert key in {"path", "chunk"}
52+
"""
53+
Group the query results based on `key`.
54+
55+
args:
56+
- `by`: either "path" or "chunk"
57+
- `top_k`: if set, only return the top k results for each group based on mean scores. If "auto", top k is decided by the mean number of results per group.
58+
59+
returns:
60+
- a dictionary that maps either path or chunk to a list of `QueryResult` object.
61+
62+
"""
63+
assert by in {"path", "chunk"}
5364
grouped_result: dict[Chunk | str, list["QueryResult"]] = defaultdict(list)
5465

5566
for res in results:
56-
grouped_result[getattr(res, key)].append(res)
67+
grouped_result[getattr(res, by)].append(res)
5768

5869
if top_k == "auto":
5970
top_k = int(numpy.mean(tuple(len(i) for i in grouped_result.values())))
@@ -67,10 +78,16 @@ def mean_score(self):
6778
return float(numpy.mean(self.scores))
6879

6980
def __lt__(self, other: "QueryResult"):
81+
assert isinstance(other, QueryResult)
7082
return self.mean_score() < other.mean_score()
7183

7284
def __gt__(self, other: "QueryResult"):
85+
assert isinstance(other, QueryResult)
7386
return self.mean_score() > other.mean_score()
7487

88+
def __eq__(self, other: object, /) -> bool:
89+
assert isinstance(other, QueryResult)
90+
return self.mean_score() == other.mean_score()
91+
7592
def is_same_doc(self, other: "QueryResult") -> bool:
7693
return self.path == other.path and self.chunk == other.chunk

tests/subcommands/query/test_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_QueryResult_group_by_chunk():
5353
res2.query = ["bye"]
5454
res2.scores = [0.1]
5555

56-
grouped_dict = QueryResult.group(res1, res2, key="chunk")
56+
grouped_dict = QueryResult.group(res1, res2, by="chunk")
5757
assert len(grouped_dict.keys()) == 1
5858
assert len(grouped_dict[res1.chunk]) == 2
5959

0 commit comments

Comments
 (0)