Skip to content

Commit af1454d

Browse files
author
Zhe Yu
committed
feat(cli): Refactor vectorise command to use DB adapter layer
1 parent 9701e18 commit af1454d

5 files changed

Lines changed: 86 additions & 16 deletions

File tree

src/vectorcode/cli_utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -708,27 +708,29 @@ def from_path(cls, spec_path: str, project_root: Optional[str] = None):
708708
return cls(spec_path, base_dir)
709709

710710
def __init__(self, spec: str | GitIgnoreSpec, base_dir: str = "."):
711+
self.spec: GitIgnoreSpec
711712
if isinstance(spec, str):
712713
with open(spec) as fin:
713714
self.spec = GitIgnoreSpec.from_lines(
714715
(i.strip() for i in fin.readlines())
715716
)
716717
else:
717718
self.spec = spec
718-
self.base_dir = base_dir
719+
self.base_dir = Path(base_dir).resolve()
720+
721+
def match_file(self, path: str, negated: bool = False) -> bool:
722+
if self.base_dir in Path(path).resolve().parents:
723+
matched = self.spec.match_file(os.path.relpath(path, self.base_dir))
724+
if negated:
725+
matched = not matched
726+
return matched
727+
return True
719728

720729
def match(
721730
self, paths: Iterable[str], negated: bool = False
722731
) -> Generator[str, None, None]:
723732
# get paths relative to `base_dir`
724733

725-
base = Path(self.base_dir).resolve()
726734
for p in paths:
727-
if base in Path(p).resolve().parents:
728-
should_yield = self.spec.match_file(os.path.relpath(p, self.base_dir))
729-
if negated:
730-
should_yield = not should_yield
731-
if should_yield:
732-
yield p
733-
else:
735+
if self.match_file(p, negated):
734736
yield p

src/vectorcode/database/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def get_embedding(self, texts: str | list[str]) -> list[NDArray]:
180180
"""
181181
if isinstance(texts, str):
182182
texts = [texts]
183+
if len(texts) == 0:
184+
return []
185+
texts = [i for i in texts]
186+
logger.debug(f"Getting embeddings for {texts}")
183187
embeddings = get_embedding_function(self._configs)(texts)
184188
if self._configs.embedding_dims:
185189
embeddings = [e[: self._configs.embedding_dims] for e in embeddings]

src/vectorcode/database/chroma0.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ async def vectorise(
363363

364364
chunks = tuple(chunker.chunk(file_path))
365365
embeddings = self.get_embedding(list(i.text for i in chunks))
366+
if len(embeddings) == 0:
367+
return VectoriseStats(skipped=1)
366368

367369
file_hash = hash_file(file_path)
368370

@@ -501,7 +503,9 @@ async def delete(self) -> int:
501503
]
502504
files_in_collection = set(
503505
str(expand_path(i.path, True))
504-
for i in (await self.list_collection_content(ResultType.document)).files
506+
for i in (
507+
await self.list_collection_content(what=ResultType.document)
508+
).files
505509
)
506510

507511
rm_paths = {

src/vectorcode/subcommands/vectorise.py renamed to src/vectorcode/subcommands/vectorise/__init__.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
)
2929
from vectorcode.database import get_database_connector
3030
from vectorcode.database.base import DatabaseConnectorBase
31-
from vectorcode.database.types import VectoriseStats
31+
from vectorcode.database.errors import CollectionNotFoundError
32+
from vectorcode.database.types import ResultType, VectoriseStats
33+
from vectorcode.subcommands.vectorise.filter import FilterManager
3234

3335
logger = logging.getLogger(name=__name__)
3436

@@ -261,19 +263,31 @@ async def vectorise(configs: Config) -> int:
261263
include_hidden=configs.include_hidden,
262264
)
263265

264-
# TODO: check file hashes
266+
filters = FilterManager()
267+
268+
try:
269+
collection_files = (
270+
await database.list_collection_content(what=ResultType.document)
271+
).files
272+
273+
existing_hashes = set(i.sha256 for i in collection_files)
274+
except CollectionNotFoundError:
275+
existing_hashes = set()
265276

266277
if not configs.force:
267278
for spec_path in find_exclude_specs(configs):
279+
# filter by gitignore/vectorcode.exclude
268280
if os.path.isfile(spec_path):
269281
logger.info(f"Loading ignore specs from {spec_path}.")
270-
files = exclude_paths_by_spec(
271-
(str(i) for i in files), spec_path, str(configs.project_root)
272-
)
273-
logger.debug(f"Files after excluding: {files}")
282+
spec = SpecResolver.from_path(spec_path)
283+
filters.add_filter(lambda x: spec.match_file(x, True))
284+
285+
# filter by sha256
286+
filters.add_filter(lambda x: hash_file(x) not in existing_hashes)
274287
else: # pragma: nocover
275288
logger.info("Ignoring exclude specs.")
276289

290+
files = list(filters(files))
277291
stats = VectoriseStats()
278292
stats_lock = Lock()
279293
semaphore = asyncio.Semaphore(os.cpu_count() or 1)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import logging
2+
import os
3+
import sys
4+
from typing import Callable, Iterable, Self, Sequence
5+
6+
logger = logging.getLogger(name=__name__)
7+
8+
FileFilter = Callable[[str], bool]
9+
10+
11+
class FilterManager:
12+
def __init__(self, from_filters: Sequence[FileFilter] | None = None) -> None:
13+
self._filters: list[FileFilter] = []
14+
if from_filters:
15+
self._filters.extend(from_filters)
16+
17+
def add_filter(self, f: FileFilter = lambda x: bool(x)) -> Self:
18+
self._filters.append(f)
19+
return self
20+
21+
def _has_debugging(self): # pragma: nocover
22+
"""
23+
Iterators are difficult to debug.
24+
Use this function to decide whether we should convert iterators to tuples
25+
to make debugging easier.
26+
"""
27+
return (
28+
sys.gettrace() is not None
29+
or os.environ.get("VECTORCODE_LOG_LEVEL") is not None
30+
)
31+
32+
def __call__(self, files: Iterable[str]) -> Iterable[str]:
33+
if self._has_debugging(): # pragma: nocover
34+
files = tuple(files)
35+
logger.debug(
36+
f"Applying the following filters: {list(i.__name__ for i in self._filters)} to the following files ({len(files)}): {files}"
37+
)
38+
39+
for f in self._filters:
40+
files = filter(f, files)
41+
42+
if self._has_debugging(): # pragma: nocover
43+
files = tuple(files)
44+
logger.debug(f"{f.__name__} remaining items ({len(files)}): {files}")
45+
46+
return files

0 commit comments

Comments
 (0)