Skip to content

Commit 91838b9

Browse files
author
Zhe Yu
committed
refactor(db): Implement get_chunks method in the database connectors
1 parent a303595 commit 91838b9

2 files changed

Lines changed: 50 additions & 14 deletions

File tree

src/vectorcode/database/base.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,9 @@
44
from abc import ABC, abstractmethod
55
from typing import Optional, Self, Sequence
66

7-
from chromadb import EmbeddingFunction
87
from numpy.typing import NDArray
98

10-
from vectorcode.chunking import TreeSitterChunker
9+
from vectorcode.chunking import Chunk, TreeSitterChunker
1110
from vectorcode.cli_utils import Config
1211
from vectorcode.common import get_embedding_function
1312
from vectorcode.database.types import (
@@ -83,7 +82,6 @@ async def vectorise(
8382
self,
8483
file_path: str,
8584
chunker: TreeSitterChunker | None = None,
86-
embedding_function: EmbeddingFunction | None = None,
8785
) -> VectoriseStats:
8886
"""
8987
Vectorise the given file and add it to the database.
@@ -187,3 +185,11 @@ def get_embedding(self, texts: str | list[str]) -> list[NDArray]:
187185
if self._configs.embedding_dims:
188186
embeddings = [e[: self._configs.embedding_dims] for e in embeddings]
189187
return embeddings
188+
189+
@abstractmethod
async def get_chunks(self, file_path) -> list[Chunk]:
    """
    Fetch the chunks stored for `file_path`.

    Implementations return one `Chunk` per stored piece of the file, and
    an empty list when the file has no chunks.
    """

src/vectorcode/database/chroma0.py

Lines changed: 41 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
import httpx
1616
from chromadb.api import AsyncClientAPI
1717
from chromadb.api.models.AsyncCollection import AsyncCollection
18-
from chromadb.api.types import EmbeddingFunction, IncludeEnum, QueryResult
18+
from chromadb.api.types import IncludeEnum, QueryResult
1919
from chromadb.config import APIVersion, Settings
2020
from chromadb.errors import InvalidCollectionException
2121
from tree_sitter import Point
@@ -28,7 +28,6 @@
2828
expand_globs,
2929
expand_path,
3030
)
31-
from vectorcode.common import get_embedding_function
3231
from vectorcode.database.base import DatabaseConnectorBase
3332
from vectorcode.database.errors import CollectionNotFoundError
3433
from vectorcode.database.types import (
@@ -355,19 +354,15 @@ async def vectorise(
355354
self,
356355
file_path: str,
357356
chunker: TreeSitterChunker | None = None,
358-
embedding_function: EmbeddingFunction | None = None,
359357
) -> VectoriseStats:
360358
collection_path = str(self._configs.project_root)
361359
collection = await self._create_or_get_collection(
362360
collection_path, allow_create=True
363361
)
364362
chunker = chunker or TreeSitterChunker(self._configs)
365-
embedding_function = cast(
366-
EmbeddingFunction,
367-
embedding_function or get_embedding_function(self._configs),
368-
)
363+
369364
chunks = tuple(chunker.chunk(file_path))
370-
embeddings = embedding_function(list(i.text for i in chunks))
365+
embeddings = self.get_embedding(list(i.text for i in chunks))
371366

372367
file_hash = hash_file(file_path)
373368

@@ -414,7 +409,7 @@ async def list_collections(self):
414409
for col_name in await client.list_collections():
415410
col = await client.get_collection(col_name)
416411
project_root = str(col.metadata.get("path"))
417-
col_counts = await self.list()
412+
col_counts = await self.list_collection_content()
418413
result.append(
419414
CollectionInfo(
420415
id=col_name,
@@ -430,7 +425,7 @@ async def list_collections(self):
430425
)
431426
return result
432427

433-
async def list(self, what=None) -> CollectionContent:
428+
async def list_collection_content(self, what=None) -> CollectionContent:
434429
"""
435430
When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.
436431
Otherwise, this method may populate only one of them to save waiting time.
@@ -494,7 +489,7 @@ async def delete(self) -> int:
494489
]
495490
files_in_collection = set(
496491
str(expand_path(i.path, True))
497-
for i in (await self.list(ResultType.document)).files
492+
for i in (await self.list_collection_content(ResultType.document)).files
498493
)
499494

500495
rm_paths = {
@@ -516,3 +511,38 @@ async def drop(
516511
async with _Chroma0ClientManager().get_client(self._configs) as client:
517512
await self._create_or_get_collection(collection_path, False)
518513
await client.delete_collection(get_collection_id(collection_path))
514+
515+
async def get_chunks(self, file_path: str) -> list[Chunk]:
    """
    Return the chunks stored in the database for the given file.

    The path is made absolute with `os.path.abspath` and used as an
    exact-match filter on the `path` metadata field of the collection
    that belongs to the configured project root.

    Args:
        file_path: Path of the file whose chunks should be fetched.

    Returns:
        One `Chunk` per stored record, carrying the record's text and ID.
        `start` and `end` are set (as `Point`s with column 0) only when
        the record's metadata contains them. An empty list is returned
        when the project has no collection or no record matches the path.
    """
    file_path = os.path.abspath(file_path)
    try:
        collection = await self._create_or_get_collection(
            collection_path=str(self._configs.project_root), allow_create=False
        )
    except CollectionNotFoundError:
        # A missing collection is reported as "no chunks"; every other
        # error propagates to the caller untouched.
        _logger.warning(
            f"There's no existing collection at {self._configs.project_root}."
        )
        return []

    raw_results = await collection.get(
        where={"path": file_path},
        include=[IncludeEnum.metadatas, IncludeEnum.documents],
    )
    metadatas = raw_results["metadatas"]
    documents = raw_results["documents"]
    # Both fields were requested through `include`; the asserts narrow the
    # optional result types before they are iterated.
    assert metadatas is not None
    assert documents is not None

    result: list[Chunk] = []
    # IDs, documents and metadata are parallel lists; `strict=True` makes a
    # length mismatch fail loudly instead of silently dropping records.
    for chunk_id, text, meta in zip(
        raw_results["ids"], documents, metadatas, strict=True
    ):
        chunk = Chunk(text=text, id=chunk_id)
        # Only the row is read from the metadata, so the column is always 0.
        start_row = meta.get("start")
        if start_row is not None:
            chunk.start = Point(row=int(start_row), column=0)
        end_row = meta.get("end")
        if end_row is not None:
            chunk.end = Point(row=int(end_row), column=0)
        result.append(chunk)
    return result

0 commit comments

Comments
 (0)