99
1010from vectorcode .chunking import TreeSitterChunker
1111from vectorcode .cli_utils import Config
12+ from vectorcode .common import get_embedding_function
1213from vectorcode .database .types import (
1314 CollectionContent ,
1415 CollectionInfo ,
@@ -64,7 +65,7 @@ async def count(self, what: ResultType = ResultType.chunk) -> int:
6465 """
6566 Returns the chunk count or file count of the given collection, depending on the value passed for `what`.
6667 """
67- collection_content = await self .list (what )
68+ collection_content = await self .list_collection_content (what )
6869 match what :
6970 case ResultType .chunk :
7071 return len (collection_content .chunks )
@@ -74,7 +75,6 @@ async def count(self, what: ResultType = ResultType.chunk) -> int:
7475 @abstractmethod
7576 async def query (
7677 self ,
77- keywords_embeddings : list [NDArray ],
7878 ) -> Sequence [QueryResult ]:
7979 pass
8080
@@ -99,7 +99,9 @@ async def list_collections(self) -> Sequence[CollectionInfo]:
9999 pass
100100
101101 @abstractmethod
102- async def list (self , what : Optional [ResultType ] = None ) -> CollectionContent :
102+ async def list_collection_content (
103+ self , what : Optional [ResultType ] = None
104+ ) -> CollectionContent :
103105 """
104106 List the content of a collection (from `self._configs.project_root`).
105107
@@ -163,7 +165,7 @@ async def check_orphanes(self) -> int:
163165 """
164166
165167 orphanes : list [str ] = []
166- database_files = (await self .list (ResultType .document )).files
168+ database_files = (await self .list_collection_content (ResultType .document )).files
167169 for file in database_files :
168170 path = file .path
169171 if not os .path .isfile (path ):
@@ -174,3 +176,14 @@ async def check_orphanes(self) -> int:
174176 await self .delete ()
175177
176178 return len (orphanes )
179+
180+ def get_embedding (self , texts : str | list [str ]) -> list [NDArray ]:
181+ """
182+ Generate embeddings and truncate them to `self._configs.embedding_dims` if needed.
183+ """
184+ if isinstance (texts , str ):
185+ texts = [texts ]
186+ embeddings = get_embedding_function (self ._configs )(texts )
187+ if self ._configs .embedding_dims :
188+ embeddings = [e [: self ._configs .embedding_dims ] for e in embeddings ]
189+ return embeddings
0 commit comments