Skip to content

Commit f2aadff

Browse files
author
Zhe Yu
committed
fix(cli): test coverage.
1 parent a5f1bb9 commit f2aadff

6 files changed

Lines changed: 81 additions & 155 deletions

File tree

src/vectorcode/mcp_main.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import sys
6+
import traceback
67
from dataclasses import dataclass
78
from pathlib import Path
89
from typing import Optional, cast
@@ -161,13 +162,16 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
161162

162163
return stats.to_dict()
163164
except Exception as e:
164-
logger.error("Failed to access collection at %s", project_root)
165-
raise McpError(
166-
ErrorData(
167-
code=1,
168-
message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
169-
)
170-
)
165+
if isinstance(e, McpError):
166+
logger.error("Failed to access collection at %s", project_root)
167+
raise
168+
else:
169+
raise McpError(
170+
ErrorData(
171+
code=1,
172+
message="\n".join(traceback.format_exception(e)),
173+
)
174+
) from e
171175

172176

173177
async def query_tool(
@@ -222,13 +226,16 @@ async def query_tool(
222226
return results
223227

224228
except Exception as e:
225-
logger.error("Failed to access collection at %s", project_root)
226-
raise McpError(
227-
ErrorData(
228-
code=1,
229-
message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
230-
)
231-
)
229+
if isinstance(e, McpError):
230+
logger.error("Failed to access collection at %s", project_root)
231+
raise
232+
else:
233+
raise McpError(
234+
ErrorData(
235+
code=1,
236+
message="\n".join(traceback.format_exception(e)),
237+
)
238+
) from e
232239

233240

234241
async def ls_files(project_root: str) -> list[str]:

src/vectorcode/subcommands/query/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def conver_query_results(
4848
metadatas = chroma_result["metadatas"][q_i]
4949
for doc, dist, meta in zip(documents, distances, metadatas):
5050
chunk = Chunk(text=doc)
51-
if meta["start"]:
51+
if meta.get("start"):
5252
chunk.start = Point(int(meta.get("start", 0)), 0)
53-
if meta["end"]:
53+
if meta.get("end"):
5454
chunk.end = Point(int(meta.get("end", 0)) + 1, 0)
5555
chroma_results_list.append(
5656
vectorcode_types.QueryResult(

src/vectorcode/subcommands/query/reranker/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def create(cls, configs: Config, **kwargs: Any):
4646
raise
4747

4848
@abstractmethod
49-
async def compute_similarity(self, results: list[QueryResult]): # pragma: nocover
49+
async def compute_similarity(
50+
self, results: list[QueryResult]
51+
) -> None: # pragma: nocover
5052
"""
5153
Modify the `QueryResult.scores` field IN-PLACE so that they contain the correct scores.
5254
"""
@@ -55,7 +57,7 @@ async def compute_similarity(self, results: list[QueryResult]): # pragma: nocov
5557
async def rerank(self, results: list[QueryResult]) -> list[str]:
5658
if len(results) == 0:
5759
return []
58-
results = await self.compute_similarity(results)
60+
await self.compute_similarity(results)
5961

6062
group_by = "path"
6163
if QueryInclude.chunk in self.configs.include:

tests/subcommands/query/test_query.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from vectorcode.cli_utils import CliAction, Config, QueryInclude
1010
from vectorcode.subcommands.query import (
1111
build_query_results,
12+
conver_query_results,
1213
get_query_result_files,
1314
query,
1415
)
@@ -47,7 +48,7 @@ def mock_collection():
4748
@pytest.fixture
4849
def mock_config():
4950
return Config(
50-
query=["test query"],
51+
query=["test query", "test query 2"],
5152
n_result=3,
5253
query_multiplier=2,
5354
chunk_size=100,
@@ -88,7 +89,7 @@ async def test_get_query_result_files(mock_collection, mock_config):
8889
# Check that query was called with the right parameters
8990
mock_collection.query.assert_called_once()
9091
args, kwargs = mock_collection.query.call_args
91-
mock_embedding_function.assert_called_once_with(["test query"])
92+
mock_embedding_function.assert_called_once_with(["test query", "test query 2"])
9293
assert kwargs["n_results"] == 6 # n_result(3) * query_multiplier(2)
9394
assert IncludeEnum.metadatas in kwargs["include"]
9495
assert IncludeEnum.distances in kwargs["include"]
@@ -98,7 +99,7 @@ async def test_get_query_result_files(mock_collection, mock_config):
9899
# Check reranker was used correctly
99100
mock_get_reranker.assert_called_once_with(mock_config)
100101
mock_reranker_instance.rerank.assert_called_once_with(
101-
mock_collection.query.return_value
102+
conver_query_results(mock_collection.query.return_value, mock_config.query)
102103
)
103104

104105
# Check the result
@@ -323,40 +324,6 @@ async def test_get_query_result_files_chunking(mock_collection, mock_config):
323324
assert result == ["file1.py", "file2.py"]
324325

325326

326-
@pytest.mark.asyncio
327-
async def test_get_query_result_files_multiple_queries(mock_collection, mock_config):
328-
# Set multiple query terms
329-
mock_config.query = ["term1", "term2", "term3"]
330-
mock_config.embedding_dims = 10
331-
332-
with (
333-
patch("vectorcode.subcommands.query.StringChunker") as MockChunker,
334-
patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
335-
):
336-
# Set up MockChunker to return the query terms as is
337-
mock_chunker_instance = MagicMock()
338-
mock_chunker_instance.chunk.side_effect = lambda q: [q]
339-
MockChunker.return_value = mock_chunker_instance
340-
341-
mock_reranker_instance = MagicMock()
342-
mock_reranker_instance.rerank = AsyncMock(return_value=["file1.py", "file2.py"])
343-
MockReranker.return_value = mock_reranker_instance
344-
345-
# Call the function
346-
result = await get_query_result_files(mock_collection, mock_config)
347-
348-
# Check that chunker was called for each query term
349-
assert mock_chunker_instance.chunk.call_count == 3
350-
351-
# Check query was called with all query terms
352-
mock_collection.query.assert_called_once()
353-
_, kwargs = mock_collection.query.call_args
354-
assert all(len(i) == 10 for i in kwargs["query_embeddings"])
355-
356-
# Check the result
357-
assert result == ["file1.py", "file2.py"]
358-
359-
360327
@pytest.mark.asyncio
361328
async def test_query_success(mock_config):
362329
# Mock all the necessary dependencies

tests/subcommands/query/test_reranker.py

Lines changed: 45 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy
55
import pytest
66

7-
from vectorcode.cli_utils import Config, QueryInclude
7+
from vectorcode.cli_utils import Config
88
from vectorcode.subcommands.query.reranker import (
99
CrossEncoderReranker,
1010
NaiveReranker,
@@ -14,6 +14,7 @@
1414
get_available_rerankers,
1515
get_reranker,
1616
)
17+
from vectorcode.subcommands.query.types import QueryResult
1718

1819

1920
@pytest.fixture(scope="function")
@@ -37,29 +38,50 @@ def naive_reranker_conf():
3738

3839

3940
@pytest.fixture(scope="function")
40-
def query_result():
41-
return {
42-
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
43-
"distances": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
44-
"metadatas": [
45-
[{"path": "file1.py"}, {"path": "file2.py"}, {"path": "file3.py"}],
46-
[{"path": "file2.py"}, {"path": "file4.py"}, {"path": "file3.py"}],
47-
],
48-
"documents": [
49-
["content1", "content2", "content3"],
50-
["content4", "content5", "content6"],
51-
],
52-
}
41+
def query_result() -> list[QueryResult]:
42+
return [
43+
QueryResult(
44+
path="file1.py",
45+
chunk=MagicMock(),
46+
query=("query chunk 1",),
47+
scores=(0.5,),
48+
),
49+
QueryResult(
50+
path="file2.py",
51+
chunk=MagicMock(),
52+
query=("query chunk 1",),
53+
scores=(0.9,),
54+
),
55+
QueryResult(
56+
path="file3.py",
57+
chunk=MagicMock(),
58+
query=("query chunk 1",),
59+
scores=(0.3,),
60+
),
61+
QueryResult(
62+
path="file2.py",
63+
chunk=MagicMock(),
64+
query=("query chunk 2",),
65+
scores=(0.6,),
66+
),
67+
QueryResult(
68+
path="file4.py",
69+
chunk=MagicMock(),
70+
query=("query chunk 2",),
71+
scores=(0.7,),
72+
),
73+
QueryResult(
74+
path="file3.py",
75+
chunk=MagicMock(),
76+
query=("query chunk 2",),
77+
scores=(0.2,),
78+
),
79+
]
5380

5481

5582
@pytest.fixture(scope="function")
5683
def empty_query_result():
57-
return {
58-
"ids": [],
59-
"distances": [],
60-
"metadatas": [],
61-
"documents": [],
62-
}
84+
return []
6385

6486

6587
@pytest.fixture(scope="function")
@@ -103,8 +125,8 @@ async def test_naive_reranker_rerank(naive_reranker_conf, query_result):
103125
assert len(result) <= naive_reranker_conf.n_result
104126

105127
# Check all returned items are strings (paths)
106-
for path in result:
107-
assert isinstance(path, str)
128+
for res in result:
129+
assert isinstance(res, str)
108130

109131

110132
@pytest.mark.asyncio
@@ -143,21 +165,7 @@ async def test_cross_encoder_reranker_rerank(mock_cross_encoder, config, query_r
143165
mock_model = MagicMock()
144166
mock_cross_encoder.return_value = mock_model
145167

146-
# Configure mock predict to return numpy array with float32 dtype
147-
scores = numpy.array([0.9, 0.7, 0.8], dtype=numpy.float32)
148-
mock_model.predict.return_value = scores
149-
150-
# Ensure complete query_result structure
151-
query_result.update(
152-
{
153-
"ids": [["id1", "id2", "id3"], ["id4", "id5", "id6"]],
154-
"documents": [["doc1", "doc2", "doc3"], ["doc4", "doc5", "doc6"]],
155-
"metadatas": [
156-
[{"path": "p1"}, {"path": "p2"}, {"path": "p3"}],
157-
[{"path": "p4"}, {"path": "p5"}, {"path": "p6"}],
158-
],
159-
}
160-
)
168+
mock_model.predict = lambda x: numpy.random.random((len(x),))
161169

162170
reranker = CrossEncoderReranker(config)
163171
result = await reranker.rerank(query_result)
@@ -184,46 +192,6 @@ async def test_naive_reranker_document_selection_logic(
184192
assert "file2.py" in result or "file3.py" in result
185193

186194

187-
@pytest.mark.asyncio
188-
async def test_naive_reranker_with_chunk_ids(naive_reranker_conf, query_result):
189-
"""Test NaiveReranker returns chunk IDs when QueryInclude.chunk is set"""
190-
naive_reranker_conf.include.append(
191-
QueryInclude.chunk
192-
) # Assuming QueryInclude.chunk would be "chunk"
193-
194-
reranker = NaiveReranker(naive_reranker_conf)
195-
result = await reranker.rerank(query_result)
196-
197-
assert isinstance(result, list)
198-
assert len(result) <= naive_reranker_conf.n_result
199-
assert all(isinstance(id, str) for id in result)
200-
assert all(id.startswith("id") for id in result) # Verify IDs not paths
201-
202-
203-
@pytest.mark.asyncio
204-
@patch("sentence_transformers.CrossEncoder")
205-
async def test_cross_encoder_reranker_with_chunk_ids(
206-
mock_cross_encoder, config, query_result
207-
):
208-
"""Test CrossEncoderReranker returns chunk IDs when QueryInclude.chunk is set"""
209-
mock_model = MagicMock()
210-
mock_cross_encoder.return_value = mock_model
211-
212-
# Setup mock to return numpy array scores
213-
scores = numpy.array([0.9, 0.7], dtype=numpy.float32)
214-
mock_model.predict.return_value = scores
215-
216-
config.include = {QueryInclude.chunk}
217-
reranker = CrossEncoderReranker(config)
218-
219-
result = await reranker.rerank(query_result)
220-
221-
mock_model.predict.assert_called()
222-
assert isinstance(result, list)
223-
assert all(isinstance(id, str) for id in result)
224-
assert all(id in ["id1", "id2", "id3", "id4"] for id in result)
225-
226-
227195
def test_get_reranker(config, naive_reranker_conf):
228196
assert get_reranker(naive_reranker_conf).configs.reranker == "NaiveReranker"
229197

0 commit comments

Comments
 (0)