Skip to content

Commit 0e1c9a5

Browse files
author
Zhe Yu
committed
feat(db): Improve database abstraction and vectorisation process
1 parent 368b57a commit 0e1c9a5

4 files changed

Lines changed: 187 additions & 102 deletions

File tree

src/vectorcode/database/base.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,13 @@
22
from abc import ABC, abstractmethod
33
from typing import Optional, Sequence
44

5+
from chromadb import EmbeddingFunction
56
from numpy.typing import NDArray
67

7-
from vectorcode.chunking import Chunk
8+
from vectorcode.chunking import TreeSitterChunker
89
from vectorcode.cli_utils import Config
910
from vectorcode.database.types import (
1011
CollectionContent,
11-
CollectionID,
1212
CollectionInfo,
1313
QueryOpts,
1414
ResultType,
@@ -38,10 +38,10 @@ def __init__(self, configs: Config):
3838
self._configs = configs
3939

4040
async def count(
41-
self, collection_id: CollectionID, what: ResultType = ResultType.chunk
41+
self, collection_path: str, what: ResultType = ResultType.chunk
4242
) -> int:
43-
"""Returns the chunk count or"""
44-
collection_content = await self.list(collection_id, what)
43+
"""Returns the chunk count or file count of the given collection, depending on the value passed for `what`."""
44+
collection_content = await self.list(collection_path, what)
4545
match what:
4646
case ResultType.chunk:
4747
return len(collection_content.chunks)
@@ -51,7 +51,7 @@ async def count(
5151
@abstractmethod
5252
async def query(
5353
self,
54-
collection_id: CollectionID,
54+
collection_path: str,
5555
keywords_embeddings: list[NDArray],
5656
opts: QueryOpts,
5757
) -> Sequence[QueryResult]:
@@ -60,11 +60,15 @@ async def query(
6060
@abstractmethod
6161
async def vectorise(
6262
self,
63-
collection_id: CollectionID,
64-
chunks: Sequence[Chunk],
65-
chunk_embeddings: Sequence[NDArray],
66-
file_hashes: Sequence[str],
63+
collection_path: str,
64+
file_path: str,
65+
chunker: TreeSitterChunker | None = None,
66+
embedding_function: EmbeddingFunction | None = None,
6767
) -> VectoriseStats:
68+
"""
69+
Vectorise the given file and add it to the database.
70+
The duplicate checking (using file hash) should be done outside of this function.
71+
"""
6872
pass
6973

7074
@abstractmethod
@@ -73,7 +77,7 @@ async def list_collections(self) -> Sequence[CollectionInfo]:
7377

7478
@abstractmethod
7579
async def list(
76-
self, collection_id: CollectionID, what: Optional[ResultType] = None
80+
self, collection_path: str, what: Optional[ResultType] = None
7781
) -> CollectionContent:
7882
"""
7983
When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.

src/vectorcode/database/chroma0.py

Lines changed: 119 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -8,20 +8,21 @@
88
import sys
99
from asyncio.subprocess import Process
1010
from dataclasses import dataclass
11-
from typing import Any, Optional
11+
from typing import Any, Optional, cast
1212
from urllib.parse import urlparse
1313

1414
import chromadb
1515
import httpx
1616
from chromadb.api import AsyncClientAPI
1717
from chromadb.api.models.AsyncCollection import AsyncCollection
18-
from chromadb.api.types import IncludeEnum, QueryResult
18+
from chromadb.api.types import EmbeddingFunction, IncludeEnum, QueryResult
1919
from chromadb.config import APIVersion, Settings
2020
from tree_sitter import Point
2121

2222
import vectorcode.subcommands.query.types as vectorcode_query_types
23-
from vectorcode.chunking import Chunk
23+
from vectorcode.chunking import Chunk, TreeSitterChunker
2424
from vectorcode.cli_utils import Config, LockManager, expand_path
25+
from vectorcode.common import get_embedding_function
2526
from vectorcode.database.base import DatabaseConnectorBase
2627
from vectorcode.database.types import (
2728
CollectionContent,
@@ -30,6 +31,7 @@
3031
ResultType,
3132
VectoriseStats,
3233
)
34+
from vectorcode.database.utils import get_collection_id, hash_file
3335
from vectorcode.subcommands.vectorise import get_uuid
3436

3537
logger = logging.getLogger(name=__name__)
@@ -248,33 +250,36 @@ def __init__(self, configs: Config):
248250
params.update(self._configs.db_params)
249251
self._configs.db_params = params
250252

251-
async def query(self, collection_id, keywords_embeddings, opts):
253+
async def query(self, collection_path, keywords_embeddings, opts):
252254
assert len(opts.keywords), "Keywords cannot be empty"
253255
assert len(keywords_embeddings) == len(opts.keywords), (
254256
"Number of embeddings must match number of keywords."
255257
)
256-
async with Chroma0ClientManager().get_client(self._configs, False) as client:
257-
collection = await client.get_collection(collection_id)
258-
query_count = opts.count or (
259-
await self.count(collection_id, ResultType.chunk)
260-
)
261-
query_result = await collection.query(
262-
query_embeddings=keywords_embeddings,
263-
include=[
264-
IncludeEnum.metadatas,
265-
IncludeEnum.documents,
266-
IncludeEnum.distances,
267-
],
268-
n_results=query_count,
269-
)
270-
return __convert_chroma_query_results(query_result, opts.keywords)
258+
collection: AsyncCollection = await self._create_or_get_collection(
259+
collection_path=collection_path, allow_create=False
260+
)
261+
query_count = opts.count or (
262+
await self.count(collection_path, ResultType.chunk)
263+
)
264+
query_result = await collection.query(
265+
query_embeddings=keywords_embeddings,
266+
include=[
267+
IncludeEnum.metadatas,
268+
IncludeEnum.documents,
269+
IncludeEnum.distances,
270+
],
271+
n_results=query_count,
272+
)
273+
return __convert_chroma_query_results(query_result, opts.keywords)
271274

272-
async def _create_or_get_collection(self, collection_id) -> AsyncCollection:
275+
async def _create_or_get_collection(
276+
self, collection_path: str, allow_create: bool = False
277+
) -> AsyncCollection:
273278
"""
274279
This method should be used by ChromaDB methods that are expected to **create a collection when not found**.
275280
For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist.
276281
"""
277-
assert self._configs.project_root is not None
282+
278283
collection_meta: dict[str, str | int] = {
279284
"path": os.path.abspath(str(self._configs.project_root)),
280285
"hostname": socket.gethostname(),
@@ -292,6 +297,9 @@ async def _create_or_get_collection(self, collection_id) -> AsyncCollection:
292297
collection_meta[meta_field_name] = db_params[key]
293298

294299
async with Chroma0ClientManager().get_client(self._configs, True) as client:
300+
collection_id = get_collection_id(collection_path)
301+
if not allow_create:
302+
return await client.get_collection(collection_id)
295303
col = await client.get_or_create_collection(
296304
collection_id, metadata=collection_meta
297305
)
@@ -303,39 +311,60 @@ async def _create_or_get_collection(self, collection_id) -> AsyncCollection:
303311

304312
return col
305313

306-
async def vectorise(self, collection_id, chunks, chunk_embeddings, file_hashes):
307-
# WIP: finish the stats.
308-
# should this method handle chunking and hash checking?
309-
stats = VectoriseStats()
310-
311-
async with Chroma0ClientManager().get_client(self._configs, True) as client:
312-
collection = await self._create_or_get_collection(collection_id)
313-
max_batch_size = await client.get_max_batch_size()
314-
for idx_batch in range(0, len(chunks), max_batch_size):
315-
this_batch: list[Chunk] = chunks[idx_batch : idx_batch + max_batch_size]
316-
metadatas = []
317-
for idx in range(len(this_batch)):
318-
this_chunk = this_batch[idx]
319-
meta: dict[str, str | int] = {
320-
"path": str(this_chunk.path),
321-
}
322-
if file_hashes:
323-
for idx in range(idx_batch, idx_batch + max_batch_size):
324-
meta["sha256"] = file_hashes[idx]
325-
if this_chunk.start and isinstance(this_chunk.start.row, int):
326-
meta["start"] = this_chunk.start.row
327-
if this_chunk.end and isinstance(this_chunk.end.row, int):
328-
meta["end"] = this_chunk.end.row
329-
330-
metadatas.append(meta)
331-
314+
async def vectorise(
315+
self,
316+
collection_path: str,
317+
file_path: str,
318+
chunker: TreeSitterChunker | None = None,
319+
embedding_function: EmbeddingFunction | None = None,
320+
) -> VectoriseStats:
321+
collection = await self._create_or_get_collection(
322+
collection_path, allow_create=True
323+
)
324+
chunker = chunker or TreeSitterChunker(self._configs)
325+
embedding_function = cast(
326+
EmbeddingFunction,
327+
embedding_function or get_embedding_function(self._configs),
328+
)
329+
chunks = tuple(chunker.chunk(file_path))
330+
embeddings = embedding_function(list(i.text for i in chunks))
331+
332+
file_hash = hash_file(file_path)
333+
334+
def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata:
335+
meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash}
336+
if chunk.start:
337+
meta["start"] = chunk.start.row
338+
339+
if chunk.end:
340+
meta["end"] = chunk.end.row
341+
return meta
342+
343+
async with Chroma0ClientManager().get_client(self._configs) as client:
344+
max_bs = await client.get_max_batch_size()
345+
for batch_start_idx in range(0, len(chunks), max_bs):
346+
batch_chunks = [
347+
chunks[i].text
348+
for i in range(
349+
batch_start_idx, min(batch_start_idx + max_bs, len(chunks))
350+
)
351+
]
352+
batch_embeddings = embeddings[
353+
batch_start_idx : batch_start_idx + max_bs
354+
]
355+
batch_meta = [
356+
chunk_to_meta(chunks[i])
357+
for i in range(
358+
batch_start_idx, min(batch_start_idx + max_bs, len(chunks))
359+
)
360+
]
332361
await collection.add(
333-
documents=[chunk.text for chunk in this_batch],
334-
ids=[get_uuid() for _ in range(max_batch_size)],
335-
embeddings=chunk_embeddings[idx_batch : idx_batch + max_batch_size],
336-
metadatas=metadatas,
362+
documents=batch_chunks,
363+
embeddings=batch_embeddings,
364+
metadatas=batch_meta,
365+
ids=[get_uuid() for _ in batch_chunks],
337366
)
338-
return stats
367+
return VectoriseStats(add=1)
339368

340369
async def list_collections(self):
341370
async with Chroma0ClientManager().get_client(
@@ -360,52 +389,51 @@ async def list_collections(self):
360389
)
361390
return result
362391

363-
async def list(self, collection_id, what=None) -> CollectionContent:
392+
async def list(self, collection_path, what=None) -> CollectionContent:
364393
"""
365394
When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.
366395
Otherwise, this method may populate only one of them to save waiting time.
367396
"""
368397
content = CollectionContent()
369-
async with Chroma0ClientManager().get_client(
370-
configs=self._configs, need_lock=False
371-
) as client:
372-
collection = await client.get_collection(collection_id)
373-
raw_content = await collection.get(
374-
include=[
375-
IncludeEnum.metadatas,
376-
IncludeEnum.documents,
377-
]
378-
)
379-
metadatas = raw_content.get("metadatas", [])
380-
documents = raw_content.get("documents", [])
381-
ids = raw_content.get("ids", [])
382-
assert metadatas
383-
assert documents
384-
assert ids
385-
if what is None or what == ResultType.document:
386-
content.files.extend(
387-
set(
388-
FileInCollection(
389-
path=str(i.get("path")), sha256=str(i.get("sha256"))
390-
)
391-
for i in metadatas
398+
collection = await self._create_or_get_collection(
399+
get_collection_id(collection_path)
400+
)
401+
raw_content = await collection.get(
402+
include=[
403+
IncludeEnum.metadatas,
404+
IncludeEnum.documents,
405+
]
406+
)
407+
metadatas = raw_content.get("metadatas", [])
408+
documents = raw_content.get("documents", [])
409+
ids = raw_content.get("ids", [])
410+
assert metadatas
411+
assert documents
412+
assert ids
413+
if what is None or what == ResultType.document:
414+
content.files.extend(
415+
set(
416+
FileInCollection(
417+
path=str(i.get("path")), sha256=str(i.get("sha256"))
392418
)
419+
for i in metadatas
393420
)
394-
if what is None or what == ResultType.chunk:
395-
for i in range(len(ids)):
396-
start, end = None, None
397-
if metadatas[i].get("start") is not None:
398-
start = Point(row=int(metadatas[i]["start"]), column=0)
399-
if metadatas[i].get("end") is not None:
400-
end = Point(row=int(metadatas[i]["end"]), column=0)
401-
content.chunks.append(
402-
Chunk(
403-
text=documents[i],
404-
path=str(metadatas[i].get("path", "")) or None,
405-
id=ids[i],
406-
start=start,
407-
end=end,
408-
)
421+
)
422+
if what is None or what == ResultType.chunk:
423+
for i in range(len(ids)):
424+
start, end = None, None
425+
if metadatas[i].get("start") is not None:
426+
start = Point(row=int(metadatas[i]["start"]), column=0)
427+
if metadatas[i].get("end") is not None:
428+
end = Point(row=int(metadatas[i]["end"]), column=0)
429+
content.chunks.append(
430+
Chunk(
431+
text=documents[i],
432+
path=str(metadatas[i].get("path", "")) or None,
433+
id=ids[i],
434+
start=start,
435+
end=end,
409436
)
437+
)
410438

411439
return content

src/vectorcode/database/types.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,16 @@ def to_table(self) -> str:
4646
headers="firstrow",
4747
)
4848

49+
def __add__(self, other) -> "VectoriseStats":
50+
assert isinstance(other, VectoriseStats), (
51+
"`VectoriseStats` can only perform arithmatics with objects of the same type."
52+
)
53+
new = VectoriseStats()
54+
for f in fields(self):
55+
f_name = f.name
56+
setattr(new, f_name, sum(getattr(i, f_name) for i in (self, other)))
57+
return new
58+
4959

5060
@dataclass
5161
class CollectionInfo:

0 commit comments

Comments
 (0)