2828QDRANT_BATCH_SIZE = 100
2929
3030
def qdrant_collection_exists(client: QdrantClient, collection_name: str) -> bool:
    """Report whether ``collection_name`` can be fetched through ``client``.

    The check is a probe: it calls ``client.get_collection(collection_name)``
    and treats a normal return as "exists".

    Args:
        client: Qdrant client used to look the collection up.
        collection_name: Name of the collection to probe.

    Returns:
        ``True`` if ``get_collection`` returns without raising, ``False`` if
        it raises any ``Exception``.
    """
    try:
        client.get_collection(collection_name)
    except Exception:
        # Deliberately broad: every failure of the lookup is reported as
        # "collection does not exist", not only a not-found response.
        return False
    return True
40-
40+
41+
4142class QdrantLocal (VectorDB ):
4243 def __init__ (
4344 self ,
4445 dim : int ,
4546 db_config : dict ,
46- db_case_config : dict ,
47+ db_case_config : QdrantLocalIndexConfig ,
4748 collection_name : str = "QdrantLocalCollection" ,
4849 drop_old : bool = False ,
4950 name : str = "QdrantLocal" ,
@@ -56,26 +57,26 @@ def __init__(
5657 self .search_parameter = self .case_config .search_param ()
5758 self .collection_name = collection_name
5859 self .client = None
59-
60+
6061 self ._primary_field = "pk"
6162 self ._vector_field = "vector"
62-
63+
6364 client = QdrantClient (** self .db_config )
64-
65+
6566 # Lets just print the parameters here for double check
6667 log .info (f"Case config: { self .case_config .index_param ()} " )
6768 log .info (f"Search parameter: { self .search_parameter } " )
68-
69+
6970 if drop_old and qdrant_collection_exists (client , self .collection_name ):
7071 log .info (f"{ self .name } client drop_old collection: { self .collection_name } " )
7172 client .delete_collection (self .collection_name )
72-
73+
7374 if not qdrant_collection_exists (client , self .collection_name ):
7475 log .info (f"{ self .name } create collection: { self .collection_name } " )
7576 self ._create_collection (dim , client )
7677
7778 client = None
78-
79+
7980 @contextmanager
8081 def init (self ):
8182 """
@@ -89,11 +90,15 @@ def init(self):
8990 yield
9091 self .client = None
9192 del self .client
92-
93+
9394 def _create_collection (self , dim : int , qdrant_client : QdrantClient ):
9495 log .info (f"Create collection: { self .collection_name } " )
95- log .info (f"Index parameters: m={ self .case_config .index_param ()['m' ]} , ef_construct={ self .case_config .index_param ()['ef_construct' ]} , on_disk={ self .case_config .index_param ()['on_disk' ]} " )
96-
96+ log .info (
97+ f"Index parameters: m={ self .case_config .index_param ()['m' ]} , "
98+ f"ef_construct={ self .case_config .index_param ()['ef_construct' ]} , "
99+ f"on_disk={ self .case_config .index_param ()['on_disk' ]} "
100+ )
101+
97102 # If the on_disk is true, we enable both on disk index and vectors.
98103 try :
99104 qdrant_client .create_collection (
@@ -104,10 +109,10 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
104109 on_disk = self .case_config .index_param ()["on_disk" ],
105110 ),
106111 hnsw_config = HnswConfigDiff (
107- m = self .case_config .index_param ()["m" ],
112+ m = self .case_config .index_param ()["m" ],
108113 ef_construct = self .case_config .index_param ()["ef_construct" ],
109114 on_disk = self .case_config .index_param ()["on_disk" ],
110- )
115+ ),
111116 )
112117
113118 qdrant_client .create_payload_index (
@@ -121,7 +126,7 @@ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
121126 return
122127 log .warning (f"Failed to create collection: { self .collection_name } error: { e } " )
123128 raise e from None
124-
129+
125130 def optimize (self , data_size : int | None = None ):
126131 assert self .client , "Please call self.init() before"
127132 # wait for vectors to be fully indexed
@@ -139,11 +144,11 @@ def optimize(self, data_size: int | None = None):
139144 )
140145 log .info (msg )
141146 return
142-
147+
143148 except Exception as e :
144149 log .warning (f"QdrantCloud ready to search error: { e } " )
145150 raise e from None
146-
151+
147152 def insert_embeddings (
148153 self ,
149154 embeddings : Iterable [list [float ]],
@@ -163,7 +168,7 @@ def insert_embeddings(
163168 assert self .client is not None
164169 assert len (embeddings ) == len (metadata )
165170 insert_count = 0
166-
171+
167172 # disable indexing for quick insertion
168173 self .client .update_collection (
169174 collection_name = self .collection_name ,
@@ -185,13 +190,13 @@ def insert_embeddings(
185190 collection_name = self .collection_name ,
186191 optimizer_config = OptimizersConfigDiff (indexing_threshold = 100 ),
187192 )
188-
193+
189194 except Exception as e :
190195 log .info (f"Failed to insert data, { e } " )
191196 return insert_count , e
192197 else :
193198 return insert_count , None
194-
199+
195200 def search_embedding (
196201 self ,
197202 query : list [float ],
@@ -203,7 +208,7 @@ def search_embedding(
203208 Should call self.init() first.
204209 """
205210 assert self .client is not None
206-
211+
207212 f = None
208213 if filters :
209214 f = Filter (
@@ -215,17 +220,13 @@ def search_embedding(
215220 ),
216221 ),
217222 ],
218- )
219- res = (
220- self .client .query_points (
221- collection_name = self .collection_name ,
222- query = query ,
223- limit = k ,
224- query_filter = f ,
225- search_params = SearchParams (** self .search_parameter ),
226-
227- ).points
228- )
229-
230- return [result .id for result in res ]
223+ )
224+ res = self .client .query_points (
225+ collection_name = self .collection_name ,
226+ query = query ,
227+ limit = k ,
228+ query_filter = f ,
229+ search_params = SearchParams (** self .search_parameter ),
230+ ).points
231231
232+ return [result .id for result in res ]
0 commit comments