Skip to content

Commit a303595

Browse files
author
Zhe Yu
committed
feat(db): Refactor database connectors to use embeddings and improve listing
1 parent f389f09 commit a303595

3 files changed

Lines changed: 21 additions & 6 deletions

File tree

src/vectorcode/database/base.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from vectorcode.chunking import TreeSitterChunker
1111
from vectorcode.cli_utils import Config
12+
from vectorcode.common import get_embedding_function
1213
from vectorcode.database.types import (
1314
CollectionContent,
1415
CollectionInfo,
@@ -64,7 +65,7 @@ async def count(self, what: ResultType = ResultType.chunk) -> int:
6465
"""
6566
Returns the chunk count or file count of the given collection, depending on the value passed for `what`.
6667
"""
67-
collection_content = await self.list(what)
68+
collection_content = await self.list_collection_content(what)
6869
match what:
6970
case ResultType.chunk:
7071
return len(collection_content.chunks)
@@ -74,7 +75,6 @@ async def count(self, what: ResultType = ResultType.chunk) -> int:
7475
@abstractmethod
7576
async def query(
7677
self,
77-
keywords_embeddings: list[NDArray],
7878
) -> Sequence[QueryResult]:
7979
pass
8080

@@ -99,7 +99,9 @@ async def list_collections(self) -> Sequence[CollectionInfo]:
9999
pass
100100

101101
@abstractmethod
102-
async def list(self, what: Optional[ResultType] = None) -> CollectionContent:
102+
async def list_collection_content(
103+
self, what: Optional[ResultType] = None
104+
) -> CollectionContent:
103105
"""
104106
List the content of a collection (from `self._configs.project_root`).
105107
@@ -163,7 +165,7 @@ async def check_orphanes(self) -> int:
163165
"""
164166

165167
orphanes: list[str] = []
166-
database_files = (await self.list(ResultType.document)).files
168+
database_files = (await self.list_collection_content(ResultType.document)).files
167169
for file in database_files:
168170
path = file.path
169171
if not os.path.isfile(path):
@@ -174,3 +176,14 @@ async def check_orphanes(self) -> int:
174176
await self.delete()
175177

176178
return len(orphanes)
179+
180+
def get_embedding(self, texts: str | list[str]) -> list[NDArray]:
181+
"""
182+
Generate embeddings and truncate them to `self._configs.embedding_dims` if needed.
183+
"""
184+
if isinstance(texts, str):
185+
texts = [texts]
186+
embeddings = get_embedding_function(self._configs)(texts)
187+
if self._configs.embedding_dims:
188+
embeddings = [e[: self._configs.embedding_dims] for e in embeddings]
189+
return embeddings

src/vectorcode/database/chroma0.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,14 @@ def __init__(self, configs: Config):
267267
params.update(self._configs.db_params)
268268
self._configs.db_params = params
269269

270-
async def query(self, keywords_embeddings):
270+
async def query(self):
271271
assert self._configs.query is not None
272272
assert len(self._configs.query), "Keywords cannot be empty"
273+
keywords_embeddings = self.get_embedding(self._configs.query)
273274
assert len(keywords_embeddings) == len(self._configs.query), (
274275
"Number of embeddings must match number of keywords."
275276
)
277+
276278
collection_path = str(self._configs.project_root)
277279
collection: AsyncCollection = await self._create_or_get_collection(
278280
collection_path=collection_path, allow_create=False

src/vectorcode/subcommands/files/ls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
async def ls(configs: Config) -> int:
1111
database = get_database_connector(configs)
12-
files = list(i.path for i in (await database.list()).files)
12+
files = list(i.path for i in (await database.list_collection_content()).files)
1313
if configs.pipe:
1414
print(json.dumps(files))
1515
else:

0 commit comments

Comments
 (0)