@@ -117,13 +117,6 @@ def __del__(self) -> None:
117117 if self ._connection :
118118 self ._connection .close ()
119119
120- async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None :
121- """
122- Asynchronous exit method to close MongoDB connections when the instance is destroyed.
123- """
124- if self ._connection_async :
125- await self ._connection_async .close ()
126-
127120 @property
128121 def connection (self ) -> Union [AsyncMongoClient , MongoClient ]:
129122 if self ._connection :
@@ -142,53 +135,51 @@ def collection(self) -> Union[AsyncCollection, Collection]:
142135 msg = "The collection is not established yet."
143136 raise DocumentStoreError (msg )
144137
145- def _connection_is_valid (self ) -> bool :
138+ def _connection_is_valid (self , connection : MongoClient ) -> bool :
146139 """
147140 Checks if the connection to MongoDB Atlas is valid.
148141
149142 :returns: True if the connection is valid, False otherwise.
150143 """
151144 try :
152- self . _connection . admin .command ("ping" ) # type: ignore[union-attr]
145+ connection . admin .command ("ping" )
153146 return True
154147 except Exception as e :
155148 logger .error (f"Connection to MongoDB Atlas failed: { e } " )
156149 return False
157150
158- async def _connection_is_valid_async (self ) -> bool :
151+ async def _connection_is_valid_async (self , connection : AsyncMongoClient ) -> bool :
159152 """
160153 Asynchronously checks if the connection to MongoDB Atlas is valid.
161154
162155 :returns: True if the connection is valid, False otherwise.
163156 """
164157 try :
165- await self . _connection_async . admin .command ("ping" ) # type: ignore[union-attr]
158+ await connection . admin .command ("ping" )
166159 return True
167160 except Exception as e :
168161 logger .error (f"Connection to MongoDB Atlas failed: { e } " )
169162 return False
170163
171- def _collection_exists (self ) -> bool :
164+ def _collection_exists (self , connection : MongoClient , database_name : str , collection_name : str ) -> bool :
172165 """
173166 Checks if the collection exists in the MongoDB Atlas database.
174167
175168 :returns: True if the collection exists, False otherwise.
176169 """
177- database = self ._connection [self .database_name ] # type: ignore[index]
178- if self .collection_name in database .list_collection_names ():
179- return True
180- return False
170+ database = connection [database_name ]
171+ return collection_name in database .list_collection_names ()
181172
182- async def _collection_exists_async (self ) -> bool :
173+ async def _collection_exists_async (
174+ self , connection : AsyncMongoClient , database_name : str , collection_name : str
175+ ) -> bool :
183176 """
184177 Asynchronously checks if the collection exists in the MongoDB Atlas database.
185178
186179 :returns: True if the collection exists, False otherwise.
187180 """
188- database = self ._connection_async [self .database_name ] # type: ignore[index]
189- if self .collection_name in await database .list_collection_names ():
190- return True
191- return False
181+ database = connection [database_name ]
182+ return collection_name in await database .list_collection_names ()
192183
193184 def _ensure_connection_setup (self ) -> None :
194185 """
@@ -202,11 +193,11 @@ def _ensure_connection_setup(self) -> None:
202193 self .mongo_connection_string .resolve_value (), driver = DriverInfo (name = "MongoDBAtlasHaystackIntegration" )
203194 )
204195
205- if not self ._connection_is_valid ():
196+ if not self ._connection_is_valid (self . _connection ):
206197 msg = "Connection to MongoDB Atlas failed."
207198 raise DocumentStoreError (msg )
208199
209- if not self ._collection_exists ():
200+ if not self ._collection_exists (self . _connection , self . database_name , self . collection_name ):
210201 msg = f"Collection '{ self .collection_name } ' does not exist in database '{ self .database_name } '."
211202 raise DocumentStoreError (msg )
212203
@@ -226,11 +217,11 @@ async def _ensure_connection_setup_async(self) -> None:
226217 self .mongo_connection_string .resolve_value (), driver = DriverInfo (name = "MongoDBAtlasHaystackIntegration" )
227218 )
228219
229- if not await self ._connection_is_valid_async ():
220+ if not await self ._connection_is_valid_async (self . _connection_async ):
230221 msg = "Connection to MongoDB Atlas failed."
231222 raise DocumentStoreError (msg )
232223
233- if not await self ._collection_exists_async ():
224+ if not await self ._collection_exists_async (self . _connection_async , self . database_name , self . collection_name ):
234225 msg = f"Collection '{ self .collection_name } ' does not exist in database '{ self .database_name } '."
235226 raise DocumentStoreError (msg )
236227
@@ -274,7 +265,8 @@ def count_documents(self) -> int:
274265 :returns: The number of documents in the document store.
275266 """
276267 self ._ensure_connection_setup ()
277- return self ._collection .count_documents ({}) # type: ignore[union-attr]
268+ assert self ._collection is not None
269+ return self ._collection .count_documents ({})
278270
279271 async def count_documents_async (self ) -> int :
280272 """
@@ -283,7 +275,8 @@ async def count_documents_async(self) -> int:
283275 :returns: The number of documents in the document store.
284276 """
285277 await self ._ensure_connection_setup_async ()
286- return await self ._collection_async .count_documents ({}) # type: ignore[union-attr]
278+ assert self ._collection_async is not None
279+ return await self ._collection_async .count_documents ({})
287280
288281 def filter_documents (self , filters : Optional [Dict [str , Any ]] = None ) -> List [Document ]:
289282 """
@@ -296,8 +289,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
296289 :returns: A list of Documents that match the given filters.
297290 """
298291 self ._ensure_connection_setup ()
292+ assert self ._collection is not None
299293 filters = _normalize_filters (filters ) if filters else None
300- documents = list (self ._collection .find (filters )) # type: ignore[union-attr]
294+ documents = list (self ._collection .find (filters ))
301295 return [self ._mongo_doc_to_haystack_doc (doc ) for doc in documents ]
302296
303297 async def filter_documents_async (self , filters : Optional [Dict [str , Any ]] = None ) -> List [Document ]:
@@ -311,8 +305,9 @@ async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None)
311305 :returns: A list of Documents that match the given filters.
312306 """
313307 await self ._ensure_connection_setup_async ()
308+ assert self ._collection_async is not None
314309 filters = _normalize_filters (filters ) if filters else None
315- documents = await self ._collection_async .find (filters ).to_list () # type: ignore[union-attr]
310+ documents = await self ._collection_async .find (filters ).to_list ()
316311 return [self ._mongo_doc_to_haystack_doc (doc ) for doc in documents ]
317312
318313 def write_documents (self , documents : List [Document ], policy : DuplicatePolicy = DuplicatePolicy .NONE ) -> int :
@@ -327,7 +322,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
327322 :returns: The number of documents written to the document store.
328323 """
329324 self ._ensure_connection_setup ()
330-
325+ assert self . _collection is not None
331326 if len (documents ) > 0 :
332327 if not isinstance (documents [0 ], Document ):
333328 msg = "param 'documents' must contain a list of objects of type Document"
@@ -342,15 +337,15 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
342337
343338 if policy == DuplicatePolicy .SKIP :
344339 operations = [UpdateOne ({"id" : doc ["id" ]}, {"$setOnInsert" : doc }, upsert = True ) for doc in mongo_documents ]
345- existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}}) # type: ignore[union-attr]
340+ existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}})
346341 written_docs -= existing_documents
347342 elif policy == DuplicatePolicy .FAIL :
348343 operations = [InsertOne (doc ) for doc in mongo_documents ]
349344 else :
350345 operations = [ReplaceOne ({"id" : doc ["id" ]}, upsert = True , replacement = doc ) for doc in mongo_documents ]
351346
352347 try :
353- self ._collection .bulk_write (operations ) # type: ignore[union-attr]
348+ self ._collection .bulk_write (operations )
354349 except BulkWriteError as e :
355350 msg = f"Duplicate documents found: { e .details ['writeErrors' ]} "
356351 raise DuplicateDocumentError (msg ) from e
@@ -371,7 +366,7 @@ async def write_documents_async(
371366 :returns: The number of documents written to the document store.
372367 """
373368 await self ._ensure_connection_setup_async ()
374-
369+ assert self . _collection_async is not None
375370 if len (documents ) > 0 :
376371 if not isinstance (documents [0 ], Document ):
377372 msg = "param 'documents' must contain a list of objects of type Document"
@@ -387,15 +382,17 @@ async def write_documents_async(
387382
388383 if policy == DuplicatePolicy .SKIP :
389384 operations = [UpdateOne ({"id" : doc ["id" ]}, {"$setOnInsert" : doc }, upsert = True ) for doc in mongo_documents ]
390- existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}}) # type: ignore[union-attr]
385+ existing_documents = await self ._collection_async .count_documents (
386+ {"id" : {"$in" : [doc .id for doc in documents ]}}
387+ )
391388 written_docs -= existing_documents
392389 elif policy == DuplicatePolicy .FAIL :
393390 operations = [InsertOne (doc ) for doc in mongo_documents ]
394391 else :
395392 operations = [ReplaceOne ({"id" : doc ["id" ]}, upsert = True , replacement = doc ) for doc in mongo_documents ]
396393
397394 try :
398- await self ._collection_async .bulk_write (operations ) # type: ignore[union-attr]
395+ await self ._collection_async .bulk_write (operations )
399396 except BulkWriteError as e :
400397 msg = f"Duplicate documents found: { e .details ['writeErrors' ]} "
401398 raise DuplicateDocumentError (msg ) from e
@@ -409,9 +406,10 @@ def delete_documents(self, document_ids: List[str]) -> None:
409406 :param document_ids: the document ids to delete
410407 """
411408 self ._ensure_connection_setup ()
409+ assert self ._collection is not None
412410 if not document_ids :
413411 return
414- self ._collection .delete_many (filter = {"id" : {"$in" : document_ids }}) # type: ignore[union-attr]
412+ self ._collection .delete_many (filter = {"id" : {"$in" : document_ids }})
415413
416414 async def delete_documents_async (self , document_ids : List [str ]) -> None :
417415 """
@@ -420,9 +418,10 @@ async def delete_documents_async(self, document_ids: List[str]) -> None:
420418 :param document_ids: the document ids to delete
421419 """
422420 await self ._ensure_connection_setup_async ()
421+ assert self ._collection_async is not None
423422 if not document_ids :
424423 return
425- await self ._collection_async .delete_many (filter = {"id" : {"$in" : document_ids }}) # type: ignore[union-attr]
424+ await self ._collection_async .delete_many (filter = {"id" : {"$in" : document_ids }})
426425
427426 def _embedding_retrieval (
428427 self ,
@@ -441,6 +440,7 @@ def _embedding_retrieval(
441440 :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
442441 """
443442 self ._ensure_connection_setup ()
443+ assert self ._collection is not None
444444 if not query_embedding :
445445 msg = "Query embedding must not be empty"
446446 raise ValueError (msg )
@@ -462,7 +462,7 @@ def _embedding_retrieval(
462462 {"$project" : {"_id" : 0 }},
463463 ]
464464 try :
465- documents = list (self ._collection .aggregate (pipeline )) # type: ignore[union-attr]
465+ documents = list (self ._collection .aggregate (pipeline ))
466466 except Exception as e :
467467 msg = f"Retrieval of documents from MongoDB Atlas failed: { e } "
468468 if filters :
@@ -490,6 +490,7 @@ async def _embedding_retrieval_async(
490490 :raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
491491 """
492492 await self ._ensure_connection_setup_async ()
493+ assert self ._collection_async is not None
493494 if not query_embedding :
494495 msg = "Query embedding must not be empty"
495496 raise ValueError (msg )
@@ -511,7 +512,8 @@ async def _embedding_retrieval_async(
511512 {"$project" : {"_id" : 0 }},
512513 ]
513514 try :
514- documents = await self ._collection_async .aggregate (pipeline ).to_list () # type: ignore[union-attr]
515+ cursor = await self ._collection_async .aggregate (pipeline )
516+ documents = await cursor .to_list (length = None )
515517 except Exception as e :
516518 msg = f"Retrieval of documents from MongoDB Atlas failed: { e } "
517519 if filters :
@@ -606,8 +608,9 @@ def _fulltext_retrieval(
606608 ]
607609
608610 self ._ensure_connection_setup ()
611+ assert self ._collection is not None
609612 try :
610- documents = list (self ._collection .aggregate (pipeline )) # type: ignore[union-attr]
613+ documents = list (self ._collection .aggregate (pipeline ))
611614 except Exception as e :
612615 error_msg = f"Failed to retrieve documents from MongoDB Atlas: { e } "
613616 if filters :
@@ -698,9 +701,9 @@ async def _fulltext_retrieval_async(
698701 ]
699702
700703 await self ._ensure_connection_setup_async ()
701-
704+ assert self . _collection_async is not None
702705 try :
703- cursor = await self ._collection_async .aggregate (pipeline ) # type: ignore[union-attr]
706+ cursor = await self ._collection_async .aggregate (pipeline )
704707 documents = await cursor .to_list (length = None )
705708 except Exception as e :
706709 error_msg = f"Failed to retrieve documents from MongoDB Atlas: { e } "
0 commit comments