Skip to content

Commit 4f25080

Browse files
committed
feat: Restrict decorating async methods with sync enignes of Redis, Mongo and SQL cores
1 parent 6414fb4 commit 4f25080

9 files changed

Lines changed: 632 additions & 314 deletions

File tree

src/cachier/core.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ def _pop_kwds_with_deprecation(kwds, name: str, default_value: bool):
156156
return kwds.pop(name, default_value)
157157

158158

159+
def _is_async_redis_client(client: Any) -> bool:
160+
if client is None:
161+
return False
162+
method_names = ("hgetall", "hset", "keys", "delete", "hget")
163+
return all(inspect.iscoroutinefunction(getattr(client, name, None)) for name in method_names)
164+
165+
159166
def cachier(
160167
hash_func: Optional[HashFunc] = None,
161168
hash_params: Optional[HashFunc] = None,
@@ -300,6 +307,42 @@ def cachier(
300307

301308
def _cachier_decorator(func):
302309
core.set_func(func)
310+
is_coroutine = inspect.iscoroutinefunction(func)
311+
312+
if backend == "mongo":
313+
if is_coroutine and not inspect.iscoroutinefunction(mongetter):
314+
msg = "Async cached functions with Mongo backend require an async mongetter."
315+
raise TypeError(msg)
316+
if (not is_coroutine) and inspect.iscoroutinefunction(mongetter):
317+
msg = "Async mongetter requires an async cached function."
318+
raise TypeError(msg)
319+
320+
if backend == "redis":
321+
if is_coroutine:
322+
if callable(redis_client):
323+
if not inspect.iscoroutinefunction(redis_client):
324+
msg = "Async cached functions with Redis backend require an async redis_client callable."
325+
raise TypeError(msg)
326+
elif not _is_async_redis_client(redis_client):
327+
msg = "Async cached functions with Redis backend require an async Redis client."
328+
raise TypeError(msg)
329+
else:
330+
if callable(redis_client) and inspect.iscoroutinefunction(redis_client):
331+
msg = "Async redis_client callable requires an async cached function."
332+
raise TypeError(msg)
333+
if _is_async_redis_client(redis_client):
334+
msg = "Async Redis client requires an async cached function."
335+
raise TypeError(msg)
336+
337+
if backend == "sql":
338+
sql_core = core
339+
assert isinstance(sql_core, _SQLCore) # noqa: S101
340+
if is_coroutine and not sql_core.has_async_engine():
341+
msg = "Async cached functions with SQL backend require an AsyncEngine sql_engine."
342+
raise TypeError(msg)
343+
if (not is_coroutine) and sql_core.has_async_engine():
344+
msg = "Async SQL engines require an async cached function."
345+
raise TypeError(msg)
303346

304347
last_cleanup = datetime.min
305348
cleanup_lock = threading.Lock()
@@ -501,8 +544,6 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds):
501544
# argument.
502545
# For async functions, we create an async wrapper that calls
503546
# _call_async.
504-
is_coroutine = inspect.iscoroutinefunction(func)
505-
506547
if is_coroutine:
507548

508549
@wraps(func)
@@ -522,6 +563,14 @@ def _clear_being_calculated():
522563
"""Mark all entries in this cache as not being calculated."""
523564
core.clear_being_calculated()
524565

566+
async def _aclear_cache():
567+
"""Clear the cache asynchronously."""
568+
await core.aclear_cache()
569+
570+
async def _aclear_being_calculated():
571+
"""Mark all entries in this cache as not being calculated asynchronously."""
572+
await core.aclear_being_calculated()
573+
525574
def _cache_dpath():
526575
"""Return the path to the cache dir, if exists; None if not."""
527576
return getattr(core, "cache_dir", None)
@@ -541,6 +590,8 @@ def _precache_value(*args, value_to_cache, **kwds): # noqa: D417
541590

542591
func_wrapper.clear_cache = _clear_cache
543592
func_wrapper.clear_being_calculated = _clear_being_calculated
593+
func_wrapper.aclear_cache = _aclear_cache
594+
func_wrapper.aclear_being_calculated = _aclear_being_calculated
544595
func_wrapper.cache_dpath = _cache_dpath
545596
func_wrapper.precache_value = _precache_value
546597
return func_wrapper

src/cachier/cores/mongo.py

Lines changed: 11 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import warnings # to warn if pymongo is missing
1414
from contextlib import suppress
1515
from datetime import datetime, timedelta
16-
from inspect import isawaitable
1716
from typing import Any, Optional, Tuple
1817

1918
from .._types import HashFunc, Mongetter
@@ -68,16 +67,7 @@ def _ensure_collection(self) -> Any:
6867

6968
with self.lock:
7069
if self.mongo_collection is None:
71-
coll = self.mongetter()
72-
if isawaitable(coll):
73-
# Avoid "coroutine was never awaited" warnings.
74-
close = getattr(coll, "close", None)
75-
if callable(close):
76-
with suppress(Exception):
77-
close()
78-
msg = "async mongetter is only supported for async cached functions"
79-
raise TypeError(msg)
80-
self.mongo_collection = coll
70+
self.mongo_collection = self.mongetter()
8171

8272
if not self._index_verified:
8373
index_inf = self.mongo_collection.index_information()
@@ -96,23 +86,17 @@ async def _ensure_collection_async(self) -> Any:
9686
if self.mongo_collection is not None and self._index_verified:
9787
return self.mongo_collection
9888

99-
coll = self.mongetter()
100-
if isawaitable(coll):
101-
coll = await coll
89+
coll = await self.mongetter()
10290
self.mongo_collection = coll
10391

10492
if not self._index_verified:
105-
index_inf = self.mongo_collection.index_information()
106-
if isawaitable(index_inf):
107-
index_inf = await index_inf
93+
index_inf = await self.mongo_collection.index_information()
10894
if _MongoCore._INDEX_NAME not in index_inf:
10995
func1key1 = IndexModel(
11096
keys=[("func", ASCENDING), ("key", ASCENDING)],
11197
name=_MongoCore._INDEX_NAME,
11298
)
113-
res = self.mongo_collection.create_indexes([func1key1])
114-
if isawaitable(res):
115-
await res
99+
await self.mongo_collection.create_indexes([func1key1])
116100
self._index_verified = True
117101

118102
return self.mongo_collection
@@ -144,9 +128,7 @@ async def aget_entry(self, args, kwds) -> Tuple[str, Optional[CacheEntry]]:
144128

145129
async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
146130
mongo_collection = await self._ensure_collection_async()
147-
res = mongo_collection.find_one({"func": self._func_str, "key": key})
148-
if isawaitable(res):
149-
res = await res
131+
res = await mongo_collection.find_one({"func": self._func_str, "key": key})
150132
if not res:
151133
return key, None
152134
val = None
@@ -188,7 +170,7 @@ async def aset_entry(self, key: str, func_res: Any) -> bool:
188170
return False
189171
mongo_collection = await self._ensure_collection_async()
190172
thebytes = pickle.dumps(func_res)
191-
res = mongo_collection.update_one(
173+
await mongo_collection.update_one(
192174
filter={"func": self._func_str, "key": key},
193175
update={
194176
"$set": {
@@ -203,8 +185,6 @@ async def aset_entry(self, key: str, func_res: Any) -> bool:
203185
},
204186
upsert=True,
205187
)
206-
if isawaitable(res):
207-
await res
208188
return True
209189

210190
def mark_entry_being_calculated(self, key: str) -> None:
@@ -217,13 +197,11 @@ def mark_entry_being_calculated(self, key: str) -> None:
217197

218198
async def amark_entry_being_calculated(self, key: str) -> None:
219199
mongo_collection = await self._ensure_collection_async()
220-
res = mongo_collection.update_one(
200+
await mongo_collection.update_one(
221201
filter={"func": self._func_str, "key": key},
222202
update={"$set": {"processing": True}},
223203
upsert=True,
224204
)
225-
if isawaitable(res):
226-
await res
227205

228206
def mark_entry_not_calculated(self, key: str) -> None:
229207
mongo_collection = self._ensure_collection()
@@ -240,13 +218,11 @@ def mark_entry_not_calculated(self, key: str) -> None:
240218
async def amark_entry_not_calculated(self, key: str) -> None:
241219
mongo_collection = await self._ensure_collection_async()
242220
with suppress(OperationFailure):
243-
res = mongo_collection.update_one(
221+
await mongo_collection.update_one(
244222
filter={"func": self._func_str, "key": key},
245223
update={"$set": {"processing": False}},
246224
upsert=False,
247225
)
248-
if isawaitable(res):
249-
await res
250226

251227
def wait_on_entry_calc(self, key: str) -> Any:
252228
time_spent = 0
@@ -266,9 +242,7 @@ def clear_cache(self) -> None:
266242

267243
async def aclear_cache(self) -> None:
268244
mongo_collection = await self._ensure_collection_async()
269-
res = mongo_collection.delete_many(filter={"func": self._func_str})
270-
if isawaitable(res):
271-
await res
245+
await mongo_collection.delete_many(filter={"func": self._func_str})
272246

273247
def clear_being_calculated(self) -> None:
274248
mongo_collection = self._ensure_collection()
@@ -279,12 +253,10 @@ def clear_being_calculated(self) -> None:
279253

280254
async def aclear_being_calculated(self) -> None:
281255
mongo_collection = await self._ensure_collection_async()
282-
res = mongo_collection.update_many(
256+
await mongo_collection.update_many(
283257
filter={"func": self._func_str, "processing": True},
284258
update={"$set": {"processing": False}},
285259
)
286-
if isawaitable(res):
287-
await res
288260

289261
def delete_stale_entries(self, stale_after: timedelta) -> None:
290262
"""Delete stale entries from the MongoDB cache."""
@@ -296,6 +268,4 @@ async def adelete_stale_entries(self, stale_after: timedelta) -> None:
296268
"""Delete stale entries from the MongoDB cache."""
297269
mongo_collection = await self._ensure_collection_async()
298270
threshold = datetime.now() - stale_after
299-
res = mongo_collection.delete_many(filter={"func": self._func_str, "time": {"$lt": threshold}})
300-
if isawaitable(res):
301-
await res
271+
await mongo_collection.delete_many(filter={"func": self._func_str, "time": {"$lt": threshold}})

0 commit comments

Comments
 (0)