88import sys
99from asyncio .subprocess import Process
1010from dataclasses import dataclass
11- from typing import Any , Optional
11+ from typing import Any , Optional , cast
1212from urllib .parse import urlparse
1313
1414import chromadb
1515import httpx
1616from chromadb .api import AsyncClientAPI
1717from chromadb .api .models .AsyncCollection import AsyncCollection
18- from chromadb .api .types import IncludeEnum , QueryResult
18+ from chromadb .api .types import EmbeddingFunction , IncludeEnum , QueryResult
1919from chromadb .config import APIVersion , Settings
2020from tree_sitter import Point
2121
2222import vectorcode .subcommands .query .types as vectorcode_query_types
23- from vectorcode .chunking import Chunk
23+ from vectorcode .chunking import Chunk , TreeSitterChunker
2424from vectorcode .cli_utils import Config , LockManager , expand_path
25+ from vectorcode .common import get_embedding_function
2526from vectorcode .database .base import DatabaseConnectorBase
2627from vectorcode .database .types import (
2728 CollectionContent ,
3031 ResultType ,
3132 VectoriseStats ,
3233)
34+ from vectorcode .database .utils import get_collection_id , hash_file
3335from vectorcode .subcommands .vectorise import get_uuid
3436
3537logger = logging .getLogger (name = __name__ )
@@ -248,33 +250,36 @@ def __init__(self, configs: Config):
248250 params .update (self ._configs .db_params )
249251 self ._configs .db_params = params
250252
async def query(self, collection_path, keywords_embeddings, opts):
    """Run a similarity query against the collection of a project.

    Args:
        collection_path: Path identifying the collection. It is forwarded to
            `_create_or_get_collection` with `allow_create=False`, so a
            missing collection is not created here.
        keywords_embeddings: One embedding per entry of `opts.keywords`.
        opts: Query options; `opts.keywords` and `opts.count` are read.

    Returns:
        The value produced by the `__convert_chroma_query_results` helper
        from the raw query result and `opts.keywords`.

    Raises:
        AssertionError: If `opts.keywords` is empty, or if the number of
            embeddings differs from the number of keywords.
    """
    assert len(opts.keywords), "Keywords cannot be empty"
    assert len(keywords_embeddings) == len(opts.keywords), (
        "Number of embeddings must match number of keywords."
    )
    collection: AsyncCollection = await self._create_or_get_collection(
        collection_path=collection_path, allow_create=False
    )
    # A falsy `opts.count` falls back to the number of chunks reported by
    # `self.count`, i.e. every chunk is requested.
    query_count = opts.count or (
        await self.count(collection_path, ResultType.chunk)
    )
    query_result = await collection.query(
        query_embeddings=keywords_embeddings,
        include=[
            IncludeEnum.metadatas,
            IncludeEnum.documents,
            IncludeEnum.distances,
        ],
        n_results=query_count,
    )
    # Inside a class body a bare `__name` identifier is rewritten to
    # `_ClassName__name`, so calling the helper by its bare name would raise
    # NameError. Looking it up through a string key avoids the mangling.
    # NOTE(review): assumes the helper is a module-level function of this
    # file -- confirm.
    convert_results = globals()["__convert_chroma_query_results"]
    return convert_results(query_result, opts.keywords)
271274
272- async def _create_or_get_collection (self , collection_id ) -> AsyncCollection :
275+ async def _create_or_get_collection (
276+ self , collection_path : str , allow_create : bool = False
277+ ) -> AsyncCollection :
273278 """
274279 This method should be used by ChromaDB methods that are expected to **create a collection when not found**.
275280 For other methods, just use `client.get_collection` and let it fail if the collection doesn't exist.
276281 """
277- assert self . _configs . project_root is not None
282+
278283 collection_meta : dict [str , str | int ] = {
279284 "path" : os .path .abspath (str (self ._configs .project_root )),
280285 "hostname" : socket .gethostname (),
@@ -292,6 +297,9 @@ async def _create_or_get_collection(self, collection_id) -> AsyncCollection:
292297 collection_meta [meta_field_name ] = db_params [key ]
293298
294299 async with Chroma0ClientManager ().get_client (self ._configs , True ) as client :
300+ collection_id = get_collection_id (collection_path )
301+ if not allow_create :
302+ return await client .get_collection (collection_id )
295303 col = await client .get_or_create_collection (
296304 collection_id , metadata = collection_meta
297305 )
@@ -303,39 +311,60 @@ async def _create_or_get_collection(self, collection_id) -> AsyncCollection:
303311
304312 return col
305313
async def vectorise(
    self,
    collection_path: str,
    file_path: str,
    chunker: TreeSitterChunker | None = None,
    embedding_function: EmbeddingFunction | None = None,
) -> VectoriseStats:
    """Chunk, embed and store a single file in the project's collection.

    Args:
        collection_path: Path identifying the collection; the collection is
            created when it does not exist yet (`allow_create=True`).
        file_path: File to chunk. It is also stored as the `path` metadata
            of every chunk and hashed with `hash_file`.
        chunker: Chunker to use. Defaults to a `TreeSitterChunker` built
            from the connector's configuration.
        embedding_function: Callable that turns a list of chunk texts into
            embeddings. Defaults to `get_embedding_function(self._configs)`.

    Returns:
        `VectoriseStats(add=1)` once the chunks have been written, or an
        empty `VectoriseStats()` when the file produced no chunks.
    """
    collection = await self._create_or_get_collection(
        collection_path, allow_create=True
    )
    chunker = chunker or TreeSitterChunker(self._configs)
    chunks = tuple(chunker.chunk(file_path))
    if not chunks:
        # Nothing to store: skip the embedding call (an empty batch has
        # nothing to embed) and do not report the file as added.
        return VectoriseStats()

    embedding_function = cast(
        EmbeddingFunction,
        embedding_function or get_embedding_function(self._configs),
    )
    embeddings = embedding_function([chunk.text for chunk in chunks])
    file_hash = hash_file(file_path)

    def chunk_to_meta(chunk: Chunk) -> chromadb.Metadata:
        # Every chunk carries the file path and hash; the start/end rows are
        # stored only when the chunk has them.
        meta: dict[str, int | str] = {"path": file_path, "sha256": file_hash}
        if chunk.start:
            meta["start"] = chunk.start.row
        if chunk.end:
            meta["end"] = chunk.end.row
        return meta

    # This is a write path, so request the lock explicitly, like
    # `_create_or_get_collection` does.
    # NOTE(review): assumes `need_lock=True` is what write operations are
    # expected to pass -- confirm against `Chroma0ClientManager.get_client`.
    async with Chroma0ClientManager().get_client(
        self._configs, need_lock=True
    ) as client:
        max_bs = await client.get_max_batch_size()
        for batch_start in range(0, len(chunks), max_bs):
            batch_end = batch_start + max_bs
            batch = chunks[batch_start:batch_end]
            await collection.add(
                documents=[chunk.text for chunk in batch],
                embeddings=embeddings[batch_start:batch_end],
                metadatas=[chunk_to_meta(chunk) for chunk in batch],
                ids=[get_uuid() for _ in batch],
            )
    return VectoriseStats(add=1)
339368
340369 async def list_collections (self ):
341370 async with Chroma0ClientManager ().get_client (
@@ -360,52 +389,51 @@ async def list_collections(self):
360389 )
361390 return result
362391
async def list(self, collection_path, what=None) -> CollectionContent:
    """
    When `what` is None, this method should populate both `CollectionContent.files` and `CollectionContent.chunks`.
    Otherwise, this method may populate only one of them to save waiting time.

    Args:
        collection_path: Path identifying the collection. It is passed
            unchanged to `_create_or_get_collection`, which derives the
            collection ID itself.
        what: `ResultType.document` to fill only `files`,
            `ResultType.chunk` to fill only `chunks`, or None for both.

    Returns:
        The populated `CollectionContent`; it is left empty when the
        collection holds no entries.
    """
    content = CollectionContent()
    # Pass the path, not `get_collection_id(collection_path)`: the helper
    # already converts the path to an ID, so converting here as well would
    # look up a different collection.
    collection = await self._create_or_get_collection(
        collection_path, allow_create=False
    )
    raw_content = await collection.get(
        include=[
            IncludeEnum.metadatas,
            IncludeEnum.documents,
        ]
    )
    ids = raw_content.get("ids", [])
    if not ids:
        # Empty collection: nothing to report.
        return content
    metadatas = raw_content.get("metadatas", [])
    documents = raw_content.get("documents", [])
    assert metadatas
    assert documents

    if what is None or what == ResultType.document:
        # `dict.fromkeys` removes duplicates like a set would, but keeps the
        # first-seen order so the listing is deterministic.
        content.files.extend(
            dict.fromkeys(
                FileInCollection(
                    path=str(meta.get("path")), sha256=str(meta.get("sha256"))
                )
                for meta in metadatas
            )
        )
    if what is None or what == ResultType.chunk:
        for chunk_id, meta, document in zip(ids, metadatas, documents):
            start, end = None, None
            if meta.get("start") is not None:
                start = Point(row=int(meta["start"]), column=0)
            if meta.get("end") is not None:
                end = Point(row=int(meta["end"]), column=0)
            content.chunks.append(
                Chunk(
                    text=document,
                    path=str(meta.get("path", "")) or None,
                    id=chunk_id,
                    start=start,
                    end=end,
                )
            )

    return content
0 commit comments