1313import warnings # to warn if pymongo is missing
1414from contextlib import suppress
1515from datetime import datetime , timedelta
16- from inspect import isawaitable
1716from typing import Any , Optional , Tuple
1817
1918from .._types import HashFunc , Mongetter
@@ -68,16 +67,7 @@ def _ensure_collection(self) -> Any:
6867
6968 with self .lock :
7069 if self .mongo_collection is None :
71- coll = self .mongetter ()
72- if isawaitable (coll ):
73- # Avoid "coroutine was never awaited" warnings.
74- close = getattr (coll , "close" , None )
75- if callable (close ):
76- with suppress (Exception ):
77- close ()
78- msg = "async mongetter is only supported for async cached functions"
79- raise TypeError (msg )
80- self .mongo_collection = coll
70+ self .mongo_collection = self .mongetter ()
8171
8272 if not self ._index_verified :
8373 index_inf = self .mongo_collection .index_information ()
@@ -96,23 +86,17 @@ async def _ensure_collection_async(self) -> Any:
9686 if self .mongo_collection is not None and self ._index_verified :
9787 return self .mongo_collection
9888
99- coll = self .mongetter ()
100- if isawaitable (coll ):
101- coll = await coll
89+ coll = await self .mongetter ()
10290 self .mongo_collection = coll
10391
10492 if not self ._index_verified :
105- index_inf = self .mongo_collection .index_information ()
106- if isawaitable (index_inf ):
107- index_inf = await index_inf
93+ index_inf = await self .mongo_collection .index_information ()
10894 if _MongoCore ._INDEX_NAME not in index_inf :
10995 func1key1 = IndexModel (
11096 keys = [("func" , ASCENDING ), ("key" , ASCENDING )],
11197 name = _MongoCore ._INDEX_NAME ,
11298 )
113- res = self .mongo_collection .create_indexes ([func1key1 ])
114- if isawaitable (res ):
115- await res
99+ await self .mongo_collection .create_indexes ([func1key1 ])
116100 self ._index_verified = True
117101
118102 return self .mongo_collection
@@ -144,9 +128,7 @@ async def aget_entry(self, args, kwds) -> Tuple[str, Optional[CacheEntry]]:
144128
145129 async def aget_entry_by_key (self , key : str ) -> Tuple [str , Optional [CacheEntry ]]:
146130 mongo_collection = await self ._ensure_collection_async ()
147- res = mongo_collection .find_one ({"func" : self ._func_str , "key" : key })
148- if isawaitable (res ):
149- res = await res
131+ res = await mongo_collection .find_one ({"func" : self ._func_str , "key" : key })
150132 if not res :
151133 return key , None
152134 val = None
@@ -188,7 +170,7 @@ async def aset_entry(self, key: str, func_res: Any) -> bool:
188170 return False
189171 mongo_collection = await self ._ensure_collection_async ()
190172 thebytes = pickle .dumps (func_res )
191- res = mongo_collection .update_one (
173+ await mongo_collection .update_one (
192174 filter = {"func" : self ._func_str , "key" : key },
193175 update = {
194176 "$set" : {
@@ -203,8 +185,6 @@ async def aset_entry(self, key: str, func_res: Any) -> bool:
203185 },
204186 upsert = True ,
205187 )
206- if isawaitable (res ):
207- await res
208188 return True
209189
210190 def mark_entry_being_calculated (self , key : str ) -> None :
@@ -217,13 +197,11 @@ def mark_entry_being_calculated(self, key: str) -> None:
217197
218198 async def amark_entry_being_calculated (self , key : str ) -> None :
219199 mongo_collection = await self ._ensure_collection_async ()
220- res = mongo_collection .update_one (
200+ await mongo_collection .update_one (
221201 filter = {"func" : self ._func_str , "key" : key },
222202 update = {"$set" : {"processing" : True }},
223203 upsert = True ,
224204 )
225- if isawaitable (res ):
226- await res
227205
228206 def mark_entry_not_calculated (self , key : str ) -> None :
229207 mongo_collection = self ._ensure_collection ()
@@ -240,13 +218,11 @@ def mark_entry_not_calculated(self, key: str) -> None:
240218 async def amark_entry_not_calculated (self , key : str ) -> None :
241219 mongo_collection = await self ._ensure_collection_async ()
242220 with suppress (OperationFailure ):
243- res = mongo_collection .update_one (
221+ await mongo_collection .update_one (
244222 filter = {"func" : self ._func_str , "key" : key },
245223 update = {"$set" : {"processing" : False }},
246224 upsert = False ,
247225 )
248- if isawaitable (res ):
249- await res
250226
251227 def wait_on_entry_calc (self , key : str ) -> Any :
252228 time_spent = 0
@@ -266,9 +242,7 @@ def clear_cache(self) -> None:
266242
267243 async def aclear_cache (self ) -> None :
268244 mongo_collection = await self ._ensure_collection_async ()
269- res = mongo_collection .delete_many (filter = {"func" : self ._func_str })
270- if isawaitable (res ):
271- await res
245+ await mongo_collection .delete_many (filter = {"func" : self ._func_str })
272246
273247 def clear_being_calculated (self ) -> None :
274248 mongo_collection = self ._ensure_collection ()
@@ -279,12 +253,10 @@ def clear_being_calculated(self) -> None:
279253
280254 async def aclear_being_calculated (self ) -> None :
281255 mongo_collection = await self ._ensure_collection_async ()
282- res = mongo_collection .update_many (
256+ await mongo_collection .update_many (
283257 filter = {"func" : self ._func_str , "processing" : True },
284258 update = {"$set" : {"processing" : False }},
285259 )
286- if isawaitable (res ):
287- await res
288260
289261 def delete_stale_entries (self , stale_after : timedelta ) -> None :
290262 """Delete stale entries from the MongoDB cache."""
@@ -296,6 +268,4 @@ async def adelete_stale_entries(self, stale_after: timedelta) -> None:
296268 """Delete stale entries from the MongoDB cache."""
297269 mongo_collection = await self ._ensure_collection_async ()
298270 threshold = datetime .now () - stale_after
299- res = mongo_collection .delete_many (filter = {"func" : self ._func_str , "time" : {"$lt" : threshold }})
300- if isawaitable (res ):
301- await res
271+ await mongo_collection .delete_many (filter = {"func" : self ._func_str , "time" : {"$lt" : threshold }})
0 commit comments