Skip to content

Commit 37ac0b1

Browse files
committed
Add argument-specific cache clearing
1 parent 422f890 commit 37ac0b1

16 files changed

Lines changed: 294 additions & 8 deletions

File tree

README.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ The Cachier wrapper adds a ``clear_cache()`` function to each wrapped function.
133133
134134
foo.clear_cache()
135135
136+
To clear only the cache entry for a specific call, pass the same arguments to ``clear_cache()`` that you would pass to the wrapped function:
137+
138+
.. code-block:: python
139+
140+
foo.clear_cache(arg1, arg2)
141+
foo.clear_cache(arg1, arg2=arg2)
142+
143+
The asynchronous ``aclear_cache()`` helper supports the same argument-specific form.
144+
136145
General Configuration
137146
----------------------
138147

src/cachier/core.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class _CachierWrappedFunc(Protocol[_P, _R_co]):
4848

4949
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R_co: ... # pragma: no cover
5050

51-
clear_cache: Callable[[], Any]
51+
clear_cache: Callable[..., Any]
5252
clear_being_calculated: Callable[[], Any]
53-
aclear_cache: Callable[[], Any]
53+
aclear_cache: Callable[..., Any]
5454
aclear_being_calculated: Callable[[], Any]
5555
cache_dpath: Callable[[], Optional[str]]
5656
precache_value: Callable[..., Any]
@@ -219,6 +219,13 @@ def _is_async_redis_client(client: Any) -> bool:
219219
return all(inspect.iscoroutinefunction(getattr(client, name, None)) for name in method_names)
220220

221221

222+
def _convert_public_cache_args(func, _is_method: bool, args: tuple, kwds: dict) -> dict:
223+
"""Convert cache-management arguments to canonical cache-key kwargs."""
224+
if _is_method:
225+
args = (None, *args)
226+
return _convert_args_kwargs(func, _is_method=_is_method, args=args, kwds=kwds)
227+
228+
222229
def cachier(
223230
hash_func: Optional[HashFunc] = None,
224231
hash_params: Optional[HashFunc] = None,
@@ -733,9 +740,14 @@ async def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
733740
def func_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
734741
return _call(*args, **kwargs) # type: ignore[arg-type]
735742

736-
def _clear_cache():
737-
"""Clear the cache."""
738-
core.clear_cache()
743+
def _clear_cache(*args, **kwds):
744+
"""Clear the cache, or only the entry matching the provided arguments."""
745+
if args or kwds:
746+
kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds)
747+
key = core.get_key((), kwargs)
748+
core.clear_cache_entry(key)
749+
else:
750+
core.clear_cache()
739751
if is_coroutine:
740752
return _ImmediateAwaitable()
741753
return None
@@ -747,9 +759,14 @@ def _clear_being_calculated():
747759
return _ImmediateAwaitable()
748760
return None
749761

750-
async def _aclear_cache():
751-
"""Clear the cache asynchronously."""
752-
await core.aclear_cache()
762+
async def _aclear_cache(*args, **kwds):
763+
"""Clear the cache asynchronously, or only the entry matching the provided arguments."""
764+
if args or kwds:
765+
kwargs = _convert_public_cache_args(func, core.func_is_method, args, kwds)
766+
key = core.get_key((), kwargs)
767+
await core.aclear_cache_entry(key)
768+
else:
769+
await core.aclear_cache()
753770

754771
async def _aclear_being_calculated():
755772
"""Mark all entries in this cache as not being calculated asynchronously."""

src/cachier/cores/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,18 @@ async def aclear_cache(self) -> None:
274274
"""
275275
await asyncio.to_thread(self.clear_cache)
276276

277+
@abc.abstractmethod
278+
def clear_cache_entry(self, key: str) -> None:
279+
"""Clear the cache entry mapped by the given key."""
280+
281+
async def aclear_cache_entry(self, key: str) -> None:
282+
"""Async-compatible variant of :meth:`clear_cache_entry`.
283+
284+
By default this runs in a thread to avoid blocking the event loop.
285+
286+
"""
287+
await asyncio.to_thread(self.clear_cache_entry, key)
288+
277289
@abc.abstractmethod
278290
def clear_being_calculated(self) -> None:
279291
"""Mark all entries in this cache as not being calculated."""

src/cachier/cores/memory.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ def clear_cache(self) -> None:
126126
# Update size metrics after clearing
127127
self._update_size_metrics()
128128

129+
def clear_cache_entry(self, key: str) -> None:
130+
with self.lock:
131+
self.cache.pop(self._hash_func_key(key), None)
132+
self._update_size_metrics()
133+
129134
def clear_being_calculated(self) -> None:
130135
with self.lock:
131136
for entry in self.cache.values():

src/cachier/cores/mongo.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,14 @@ async def aclear_cache(self) -> None:
249249
mongo_collection = await self._ensure_collection_async()
250250
await mongo_collection.delete_many(filter={"func": self._func_str})
251251

252+
def clear_cache_entry(self, key: str) -> None:
253+
mongo_collection = self._ensure_collection()
254+
mongo_collection.delete_one(filter={"func": self._func_str, "key": key})
255+
256+
async def aclear_cache_entry(self, key: str) -> None:
257+
mongo_collection = await self._ensure_collection_async()
258+
await mongo_collection.delete_one(filter={"func": self._func_str, "key": key})
259+
252260
def clear_being_calculated(self) -> None:
253261
mongo_collection = self._ensure_collection()
254262
mongo_collection.update_many(

src/cachier/cores/pickle.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,17 @@ def clear_cache(self) -> None:
416416
else:
417417
self._save_cache({})
418418

419+
def clear_cache_entry(self, key: str) -> None:
420+
if self.separate_files:
421+
with suppress(FileNotFoundError):
422+
os.remove(f"{self.cache_fpath}_{key}")
423+
return
424+
425+
with self.lock:
426+
cache = self.get_cache_dict()
427+
cache.pop(key, None)
428+
self._save_cache(cache)
429+
419430
def clear_being_calculated(self) -> None:
420431
if self.separate_files:
421432
self._clear_being_calculated_all_cache_files()

src/cachier/cores/redis.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,16 @@ def clear_cache(self) -> None:
353353
except Exception as e:
354354
warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2)
355355

356+
def clear_cache_entry(self, key: str) -> None:
357+
"""Clear the cache entry mapped by the given key."""
358+
redis_client = self._resolve_redis_client()
359+
redis_key = self._get_redis_key(key)
360+
361+
try:
362+
redis_client.delete(redis_key)
363+
except Exception as e:
364+
warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2)
365+
356366
async def aclear_cache(self) -> None:
357367
"""Clear the cache of this core asynchronously."""
358368
redis_client = await self._resolve_redis_client_async()
@@ -365,6 +375,16 @@ async def aclear_cache(self) -> None:
365375
except Exception as e:
366376
warnings.warn(f"Redis clear_cache failed: {e}", stacklevel=2)
367377

378+
async def aclear_cache_entry(self, key: str) -> None:
379+
"""Clear the cache entry mapped by the given key asynchronously."""
380+
redis_client = await self._resolve_redis_client_async()
381+
redis_key = self._get_redis_key(key)
382+
383+
try:
384+
await redis_client.delete(redis_key)
385+
except Exception as e:
386+
warnings.warn(f"Redis clear_cache_entry failed: {e}", stacklevel=2)
387+
368388
def clear_being_calculated(self) -> None:
369389
"""Mark all entries in this cache as not being calculated."""
370390
redis_client = self._resolve_redis_client()

src/cachier/cores/s3.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,15 @@ def clear_cache(self) -> None:
333333
except Exception as exc:
334334
_safe_warn(f"S3 clear_cache failed: {exc}")
335335

336+
def clear_cache_entry(self, key: str) -> None:
337+
"""Delete the cache entry mapped by the given key from S3."""
338+
client = self._get_s3_client()
339+
s3_key = self._get_s3_key(key)
340+
try:
341+
client.delete_object(Bucket=self.s3_bucket, Key=s3_key)
342+
except Exception as exc:
343+
_safe_warn(f"S3 clear_cache_entry failed: {exc}")
344+
336345
def clear_being_calculated(self) -> None:
337346
"""Reset the ``_processing`` flag on all entries for this function in S3."""
338347
client = self._get_s3_client()

src/cachier/cores/sql.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,12 +434,28 @@ def clear_cache(self) -> None:
434434
session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str))
435435
session.commit()
436436

437+
def clear_cache_entry(self, key: str) -> None:
438+
session_factory = self._get_sync_session()
439+
with self._lock, session_factory() as session:
440+
session.execute(
441+
delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key))
442+
)
443+
session.commit()
444+
437445
async def aclear_cache(self) -> None:
438446
session_factory = await self._get_async_session()
439447
async with session_factory() as session:
440448
await session.execute(delete(CacheTable).where(CacheTable.function_id == self._func_str))
441449
await session.commit()
442450

451+
async def aclear_cache_entry(self, key: str) -> None:
452+
session_factory = await self._get_async_session()
453+
async with session_factory() as session:
454+
await session.execute(
455+
delete(CacheTable).where(and_(CacheTable.function_id == self._func_str, CacheTable.key == key))
456+
)
457+
await session.commit()
458+
443459
def clear_being_calculated(self) -> None:
444460
session_factory = self._get_sync_session()
445461
with self._lock, session_factory() as session:

tests/mongo_tests/test_mongo_core.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,28 @@ def _test_mongo_caching(arg_1, arg_2):
139139
assert val6 == val5
140140

141141

142+
@pytest.mark.mongo
143+
def test_mongo_clear_cache_for_specific_arguments():
144+
"""clear_cache can remove one Mongo cache entry by function arguments."""
145+
146+
@cachier(mongetter=_test_mongetter)
147+
def _test_mongo_caching(arg_1, arg_2):
148+
"""Some function."""
149+
return random() + arg_1 + arg_2
150+
151+
_test_mongo_caching.clear_cache()
152+
val1 = _test_mongo_caching(1, arg_2=2)
153+
val2 = _test_mongo_caching(3, arg_2=4)
154+
assert _test_mongo_caching(1, arg_2=2) == val1
155+
assert _test_mongo_caching(3, arg_2=4) == val2
156+
157+
_test_mongo_caching.clear_cache(1, arg_2=2)
158+
159+
assert _test_mongo_caching(1, arg_2=2) != val1
160+
assert _test_mongo_caching(3, arg_2=4) == val2
161+
_test_mongo_caching.clear_cache()
162+
163+
142164
@pytest.mark.mongo
143165
def test_mongo_stale_after():
144166
"""Testing MongoDB core stale_after functionality."""

0 commit comments

Comments
 (0)