Skip to content

Commit 9b3763b

Browse files
author
Zhe Yu
committed
refactor(cli): Refactor mcp_main to use DB adapter layer
1 parent: cd43fca | commit: 9b3763b

1 file changed

Lines changed: 69 additions & 104 deletions

File tree

src/vectorcode/mcp_main.py

Lines changed: 69 additions & 104 deletions
Original file line number | Diff line number | Diff line change
@@ -6,17 +6,18 @@
66
import traceback
77
from dataclasses import dataclass
88
from pathlib import Path
9-
from typing import Optional, cast
9+
from typing import Optional
1010

1111
import shtab
12-
from chromadb.types import Where
1312

13+
from vectorcode.chunking import StringChunker
14+
from vectorcode.database import get_database_connector
15+
from vectorcode.database.types import ResultType
1416
from vectorcode.subcommands.vectorise import (
17+
FilterManager,
1518
VectoriseStats,
16-
chunked_add,
17-
exclude_paths_by_spec,
1819
find_exclude_specs,
19-
remove_orphanes,
20+
vectorise_worker,
2021
)
2122

2223
try: # pragma: nocover
@@ -32,7 +33,7 @@
3233
from vectorcode.cli_utils import (
3334
Config,
3435
LockManager,
35-
cleanup_path,
36+
SpecResolver,
3637
config_logging,
3738
expand_globs,
3839
expand_path,
@@ -42,12 +43,12 @@
4243
)
4344
from vectorcode.common import (
4445
ClientManager,
45-
get_collection,
46-
get_collections,
47-
list_collection_files,
4846
)
4947
from vectorcode.subcommands.prompt import prompt_by_categories
50-
from vectorcode.subcommands.query import get_query_result_files
48+
from vectorcode.subcommands.query import (
49+
_prepare_formatted_result,
50+
get_reranked_results,
51+
)
5152

5253
logger = logging.getLogger(name=__name__)
5354
locks = LockManager()
@@ -91,15 +92,12 @@ def get_arg_parser():
9192

9293

9394
async def list_collections() -> list[str]:
94-
names: list[str] = []
95-
async with ClientManager().get_client(
96-
await load_config_file(default_project_root)
97-
) as client:
98-
async for col in get_collections(client):
99-
if col.metadata is not None:
100-
names.append(cleanup_path(str(col.metadata.get("path"))))
101-
logger.info("Retrieved the following collections: %s", names)
102-
return names
95+
"""
96+
Returns a list of paths to the projects that have been indexed in the database.
97+
"""
98+
99+
config = await load_config_file(default_project_root)
100+
return [i.path for i in await get_database_connector(config).list_collections()]
103101

104102

105103
async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]:
@@ -113,52 +111,38 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
113111
ErrorData(code=1, message=f"{project_root} is not a valid path.")
114112
)
115113
config = await get_project_config(project_root)
114+
115+
paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
116+
final_config = await config.merge_from(
117+
Config(
118+
files=[i for i in paths if os.path.isfile(i)],
119+
project_root=project_root,
120+
)
121+
)
122+
filters = FilterManager()
123+
for ignore_spec_file in find_exclude_specs(final_config):
124+
if os.path.isfile(ignore_spec_file):
125+
logger.info(f"Loading ignore specs from {ignore_spec_file}.")
126+
spec = SpecResolver.from_path(ignore_spec_file)
127+
filters.add_filter(lambda x: spec.match_file(x, True))
128+
129+
final_config.files = list(filters(paths))
130+
131+
database = get_database_connector(final_config)
116132
try:
117-
async with ClientManager().get_client(config) as client:
118-
collection = await get_collection(client, config, True)
119-
if collection is None: # pragma: nocover
120-
raise McpError(
121-
ErrorData(
122-
code=1,
123-
message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
124-
)
125-
)
126-
paths = [os.path.expanduser(i) for i in await expand_globs(paths)]
127-
final_config = await config.merge_from(
128-
Config(
129-
files=[i for i in paths if os.path.isfile(i)],
130-
project_root=project_root,
131-
)
133+
stats = VectoriseStats()
134+
stats_lock = asyncio.Lock()
135+
semaphore = asyncio.Semaphore(os.cpu_count() or 1)
136+
tasks = [
137+
asyncio.create_task(
138+
vectorise_worker(database, file, semaphore, stats, stats_lock)
132139
)
133-
for ignore_spec in find_exclude_specs(final_config):
134-
if os.path.isfile(ignore_spec):
135-
logger.info(f"Loading ignore specs from {ignore_spec}.")
136-
paths = exclude_paths_by_spec((str(i) for i in paths), ignore_spec)
137-
138-
stats = VectoriseStats()
139-
collection_lock = asyncio.Lock()
140-
stats_lock = asyncio.Lock()
141-
max_batch_size = await client.get_max_batch_size()
142-
semaphore = asyncio.Semaphore(os.cpu_count() or 1)
143-
tasks = [
144-
asyncio.create_task(
145-
chunked_add(
146-
str(file),
147-
collection,
148-
collection_lock,
149-
stats,
150-
stats_lock,
151-
final_config,
152-
max_batch_size,
153-
semaphore,
154-
)
155-
)
156-
for file in paths
157-
]
158-
for i, task in enumerate(asyncio.as_completed(tasks), start=1):
159-
await task
140+
for file in paths
141+
]
142+
for i, task in enumerate(asyncio.as_completed(tasks), start=1):
143+
await task
160144

161-
await remove_orphanes(collection, collection_lock, stats, stats_lock)
145+
await database.check_orphanes()
162146

163147
return stats.to_dict()
164148
except Exception as e: # pragma: nocover
@@ -195,36 +179,15 @@ async def query_tool(
195179
)
196180
)
197181
config = await get_project_config(project_root)
182+
config.query = []
183+
chunker = StringChunker(config)
184+
for message in query_messages:
185+
config.query.extend(str(i) for i in chunker.chunk(message))
186+
config.n_result = n_query
198187
try:
199-
async with ClientManager().get_client(config) as client:
200-
collection = await get_collection(client, config, False)
201-
202-
if collection is None: # pragma: nocover
203-
raise McpError(
204-
ErrorData(
205-
code=1,
206-
message=f"Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
207-
)
208-
)
209-
query_config = await config.merge_from(
210-
Config(n_result=n_query, query=query_messages)
211-
)
212-
logger.info("Built the final config: %s", query_config)
213-
result_paths = await get_query_result_files(
214-
collection=collection,
215-
configs=query_config,
216-
)
217-
results: list[str] = []
218-
for result in result_paths:
219-
if isinstance(result, str):
220-
if os.path.isfile(result):
221-
with open(result) as fin:
222-
rel_path = os.path.relpath(result, config.project_root)
223-
results.append(
224-
f"<path>{rel_path}</path>\n<content>{fin.read()}</content>",
225-
)
226-
logger.info("Retrieved the following files: %s", result_paths)
227-
return results
188+
database = get_database_connector(config)
189+
reranked_results = await get_reranked_results(config, database)
190+
return list(str(i) for i in _prepare_formatted_result(reranked_results))
228191

229192
except Exception as e: # pragma: nocover
230193
if isinstance(e, McpError):
@@ -244,8 +207,13 @@ async def ls_files(project_root: str) -> list[str]:
244207
project_root: Directory to the repository. MUST be from the vectorcode `ls` tool or user input;
245208
"""
246209
configs = await get_project_config(expand_path(project_root, True))
247-
async with ClientManager().get_client(configs) as client:
248-
return await list_collection_files(await get_collection(client, configs, False))
210+
database = get_database_connector(configs)
211+
return list(
212+
i.path
213+
for i in (
214+
await database.list_collection_content(what=ResultType.document)
215+
).files
216+
)
249217

250218

251219
async def rm_files(files: list[str], project_root: str):
@@ -254,17 +222,14 @@ async def rm_files(files: list[str], project_root: str):
254222
project_root: Directory to the repository. MUST be from the vectorcode `ls` tool or user input;
255223
"""
256224
configs = await get_project_config(expand_path(project_root, True))
257-
async with ClientManager().get_client(configs) as client:
258-
try:
259-
collection = await get_collection(client, configs, False)
260-
files = [str(expand_path(i, True)) for i in files if os.path.isfile(i)]
261-
if files:
262-
await collection.delete(where=cast(Where, {"path": {"$in": files}}))
263-
else: # pragma: nocover
264-
logger.warning(f"All paths were invalid: {files}")
265-
except ValueError: # pragma: nocover
266-
logger.warning(f"Failed to find the collection at {configs.project_root}")
267-
return
225+
configs.rm_paths = [str(expand_path(i, True)) for i in files if os.path.isfile(i)]
226+
227+
if configs.rm_paths:
228+
database = get_database_connector(configs)
229+
num_deleted = await database.delete()
230+
return f"Removed {num_deleted} files from the database of the project located at {project_root}"
231+
else:
232+
logger.warning(f"The provided paths were invalid: {configs.rm_paths}")
268233

269234

270235
async def mcp_server():

0 commit comments

Comments
 (0)