Skip to content

Commit c84ff39

Browse files
author
Zhe Yu
committed
refactor(cli): make function signatures more consistent.
1 parent 0d6e338 commit c84ff39

6 files changed

Lines changed: 74 additions & 87 deletions

File tree

src/vectorcode/chunking.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ class Chunk:
2727
text: str
2828
start: Point | None = None
2929
end: Point | None = None
30+
path: str | None = None
31+
id: str | None = None
3032

3133
def __str__(self):
3234
return self.text
3335

3436
def __hash__(self) -> int:
35-
return hash(f"VectorCodeChunk({self.start}:{self.end}@{self.text})")
37+
return hash(f"VectorCodeChunk_{self.path}({self.start}:{self.end}@{self.text})")
3638

3739
def export_dict(self):
3840
d: dict[str, str | dict[str, int]] = {"text": self.text}
@@ -48,6 +50,10 @@ def export_dict(self):
4850
"end": {"row": self.end.row, "column": self.end.column},
4951
}
5052
)
53+
if self.path is not None:
54+
d["path"] = self.path
55+
if self.id:
56+
d["chunk_id"] = self.id
5157
return d
5258

5359

src/vectorcode/subcommands/query/__init__.py

Lines changed: 43 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import os
44
from typing import Any, cast
55

6-
from chromadb import GetResult, Where
6+
from chromadb import Where
77
from chromadb.api.models.AsyncCollection import AsyncCollection
88
from chromadb.api.types import IncludeEnum, QueryResult
99
from chromadb.errors import InvalidCollectionException, InvalidDimensionException
@@ -39,19 +39,23 @@ def convert_query_results(
3939
assert chroma_result["documents"] is not None
4040
assert chroma_result["distances"] is not None
4141
assert chroma_result["metadatas"] is not None
42+
assert chroma_result["ids"] is not None
4243

4344
chroma_results_list: list[vectorcode_types.QueryResult] = []
4445
for q_i in range(len(queries)):
4546
q = queries[q_i]
4647
documents = chroma_result["documents"][q_i]
4748
distances = chroma_result["distances"][q_i]
4849
metadatas = chroma_result["metadatas"][q_i]
49-
for doc, dist, meta in zip(documents, distances, metadatas):
50-
chunk = Chunk(text=doc)
50+
ids = chroma_result["ids"][q_i]
51+
for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
52+
chunk = Chunk(text=doc, id=_id)
5153
if meta.get("start"):
5254
chunk.start = Point(int(meta.get("start", 0)), 0)
5355
if meta.get("end"):
54-
chunk.end = Point(int(meta.get("end", 0)) + 1, 0)
56+
chunk.end = Point(int(meta.get("end", 0)), 0)
57+
if meta.get("path"):
58+
chunk.path = str(meta["path"])
5559
chroma_results_list.append(
5660
vectorcode_types.QueryResult(
5761
chunk=chunk,
@@ -65,7 +69,7 @@ def convert_query_results(
6569

6670
async def get_query_result_files(
6771
collection: AsyncCollection, configs: Config
68-
) -> list[str]:
72+
) -> list[str | Chunk]:
6973
query_chunks = []
7074
assert configs.query, "Query messages cannot be empty."
7175
chunker = StringChunker(configs)
@@ -126,63 +130,43 @@ async def get_query_result_files(
126130
async def build_query_results(
127131
collection: AsyncCollection, configs: Config
128132
) -> list[dict[str, str | int]]:
129-
structured_result = []
130-
for identifier in await get_query_result_files(collection, configs):
131-
if os.path.isfile(identifier):
132-
if configs.use_absolute_path:
133-
output_path = os.path.abspath(identifier)
134-
else:
135-
output_path = os.path.relpath(identifier, configs.project_root)
136-
full_result = {"path": output_path}
137-
with open(identifier) as fin:
138-
document = fin.read()
139-
full_result["document"] = document
133+
assert configs.project_root
140134

141-
structured_result.append(
142-
{str(key): full_result[str(key)] for key in configs.include}
143-
)
144-
elif QueryInclude.chunk in configs.include:
145-
chunks: GetResult = await collection.get(
146-
identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
147-
)
148-
meta = chunks.get(
149-
"metadatas",
150-
)
151-
if meta is not None and len(meta) != 0:
152-
chunk_texts = chunks.get("documents")
153-
assert chunk_texts is not None, (
154-
"QueryResult does not contain `documents`!"
155-
)
156-
full_result: dict[str, str | int] = {
157-
"chunk": str(chunk_texts[0]),
158-
"chunk_id": identifier,
159-
}
160-
if meta[0].get("start") is not None and meta[0].get("end") is not None:
161-
path = str(meta[0].get("path"))
162-
with open(path) as fin:
163-
start: int = int(meta[0]["start"])
164-
end: int = int(meta[0]["end"])
165-
full_result["chunk"] = "".join(fin.readlines()[start : end + 1])
166-
full_result["start_line"] = start
167-
full_result["end_line"] = end
168-
if QueryInclude.path in configs.include:
169-
full_result["path"] = str(
170-
meta[0]["path"]
171-
if configs.use_absolute_path
172-
else os.path.relpath(
173-
str(meta[0]["path"]), str(configs.project_root)
174-
)
175-
)
176-
177-
structured_result.append(full_result)
178-
else: # pragma: nocover
179-
logger.error(
180-
"This collection doesn't support chunk-mode output because it lacks the necessary metadata. Please re-vectorise it.",
181-
)
135+
def make_output_path(path: str, absolute: bool) -> str:
136+
if absolute:
137+
if os.path.isabs(path):
138+
return path
139+
return os.path.abspath(os.path.join(str(configs.project_root), path))
140+
else:
141+
rel_path = os.path.relpath(path, configs.project_root)
142+
if isinstance(rel_path, bytes): # pragma: nocover
143+
# for some reason, some python versions report that `os.path.relpath` may return bytes, so decode it to a string.
144+
rel_path = rel_path.decode()
145+
return rel_path
182146

147+
structured_result = []
148+
for res in await get_query_result_files(collection, configs):
149+
if isinstance(res, str):
150+
output_path = make_output_path(res, configs.use_absolute_path)
151+
io_path = make_output_path(res, True)
152+
if not os.path.isfile(io_path):
153+
logger.warning(f"{io_path} is no longer a valid file.")
154+
continue
155+
with open(io_path) as fin:
156+
structured_result.append({"path": output_path, "document": fin.read()})
183157
else:
184-
logger.warning(
185-
f"{identifier} is no longer a valid file! Please re-run vectorcode vectorise to refresh the database.",
158+
res = cast(Chunk, res)
159+
assert res.path, f"{res} has no `path` attribute."
160+
structured_result.append(
161+
{
162+
"path": make_output_path(res.path, configs.use_absolute_path)
163+
if res.path is not None
164+
else None,
165+
"chunk": res.text,
166+
"start_line": res.start.row if res.start is not None else None,
167+
"end_line": res.end.row if res.end is not None else None,
168+
"chunk_id": res.id,
169+
}
186170
)
187171
for result in structured_result:
188172
if result.get("path") is not None:

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,28 +50,33 @@ async def compute_similarity(
5050
self, results: list[QueryResult]
5151
) -> None: # pragma: nocover
5252
"""
53-
Modify the `QueryResult.scores` field IN-PLACE so that they contain the correct scores.
53+
Modify the `QueryResult.scores` field **IN-PLACE** so that they contain the correct scores.
5454
"""
5555
raise NotImplementedError
5656

57-
async def rerank(self, results: list[QueryResult]) -> list[str]:
57+
async def rerank(self, results: list[QueryResult]) -> list[str | Chunk]:
5858
if len(results) == 0:
5959
return []
60+
61+
# compute the similarity scores
6062
await self.compute_similarity(results)
6163

64+
# group the results by the query type: file (path) or chunk
65+
# and only keep the `top_k` results for each group
6266
group_by = "path"
6367
if QueryInclude.chunk in self.configs.include:
6468
group_by = "chunk"
6569
grouped_results = QueryResult.group(*results, by=group_by, top_k="auto")
6670

71+
# compute the mean scores for each of the groups
6772
scores: dict[Chunk | str, float] = {}
6873
for key in grouped_results.keys():
6974
scores[key] = float(
7075
numpy.mean(tuple(i.mean_score() for i in grouped_results[key]))
7176
)
7277

7378
return list(
74-
str(i)
79+
i
7580
for i in heapq.nlargest(
7681
self.configs.n_result, grouped_results.keys(), key=lambda x: scores[x]
7782
)

src/vectorcode/subcommands/query/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,9 @@ def __gt__(self, other: "QueryResult"):
8686
return self.mean_score() > other.mean_score()
8787

8888
def __eq__(self, other: object, /) -> bool:
89-
assert isinstance(other, QueryResult)
90-
return self.mean_score() == other.mean_score()
89+
return (
90+
isinstance(other, QueryResult) and self.mean_score() == other.mean_score()
91+
)
9192

9293
def is_same_doc(self, other: "QueryResult") -> bool:
9394
return self.path == other.path and self.chunk == other.chunk

tests/subcommands/query/test_query.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
4-
from chromadb import GetResult
4+
from chromadb import QueryResult
55
from chromadb.api.models.AsyncCollection import AsyncCollection
66
from chromadb.api.types import IncludeEnum
77
from chromadb.errors import InvalidCollectionException, InvalidDimensionException
@@ -136,43 +136,34 @@ async def test_build_query_results_chunk_mode_success(mock_collection, mock_conf
136136
mock_config.include = [QueryInclude.chunk, QueryInclude.path]
137137
mock_config.project_root = "/test/project"
138138
mock_config.use_absolute_path = False
139-
identifier = "chunk_id_1"
139+
mock_config.query = ["dummy_query"]
140+
identifier = "chunk_id"
140141
file_path = "/test/project/subdir/file1.py"
141142
relative_path = "subdir/file1.py"
142143
start_line = 5
143144
end_line = 10
144145

145146
full_file_content_lines = [f"line {i}\n" for i in range(15)]
146-
full_file_content = "".join(full_file_content_lines)
147147

148148
expected_chunk_content = "".join(full_file_content_lines[start_line : end_line + 1])
149149

150-
mock_get_result = GetResult(
151-
ids=[identifier],
152-
embeddings=None,
153-
documents=["original chunk doc in db"],
154-
metadatas=[{"path": file_path, "start": start_line, "end": end_line}],
150+
mock_get_result = QueryResult(
151+
ids=[[identifier]],
152+
documents=[[expected_chunk_content]],
153+
metadatas=[[{"path": file_path, "start": start_line, "end": end_line}]],
154+
distances=[[0.2]],
155155
)
156-
156+
mock_collection.query = AsyncMock(return_value=mock_get_result)
157157
with (
158158
patch(
159159
"vectorcode.subcommands.query.get_query_result_files",
160-
return_value=[identifier],
160+
return_value=await get_query_result_files(mock_collection, mock_config),
161161
),
162162
patch("os.path.isfile", return_value=False),
163-
patch("builtins.open", mock_open(read_data=full_file_content)) as mocked_open,
164163
patch("os.path.relpath", return_value=relative_path) as mock_relpath,
165164
):
166-
mock_collection.get = AsyncMock(return_value=mock_get_result)
167-
168165
results = await build_query_results(mock_collection, mock_config)
169166

170-
mock_collection.get.assert_called_once_with(
171-
identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
172-
)
173-
174-
mocked_open.assert_called_once_with(file_path)
175-
176167
mock_relpath.assert_called_once_with(file_path, str(mock_config.project_root))
177168

178169
assert len(results) == 1

tests/subcommands/query/test_reranker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,15 @@ async def test_naive_reranker_rerank_chunks(naive_reranker_conf, query_result):
134134
"""Test basic reranking functionality of NaiveReranker"""
135135
naive_reranker_conf.include = [QueryInclude.chunk]
136136
reranker = NaiveReranker(naive_reranker_conf)
137-
chunk_text = {str(i.chunk) for i in query_result}
137+
chunks = {i.chunk for i in query_result}
138138
result = await reranker.rerank(query_result)
139139

140140
# Check the result is a list of chunks with correct length
141141
assert isinstance(result, list)
142142
assert len(result) <= naive_reranker_conf.n_result
143143

144144
for res in result:
145-
assert res in chunk_text
145+
assert res in chunks
146146

147147

148148
@pytest.mark.asyncio

0 commit comments

Comments
 (0)