Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ These parameters can be changed at any time and they will apply to all decorator
* `stale_after`
* `next_time`
* `wait_for_calc_timeout`
* `cleanup_stale`
* `cleanup_interval`

The current defaults can be fetched by calling `get_default_params`.

Expand Down Expand Up @@ -192,6 +194,17 @@ Sometimes you may want your function to trigger a calculation when it encounters

Further function calls made while the calculation is being performed will not trigger redundant calculations.

Automatic Cleanup of Stale Values
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Setting ``cleanup_stale=True`` on a decorator enables automatic removal of stale cache entries: whenever the decorated function is called and at least ``cleanup_interval`` has passed since the last cleanup, a background task deletes entries older than ``stale_after``. The interval defaults to one day.

.. code-block:: python

@cachier(stale_after=timedelta(seconds=30), cleanup_stale=True)
def compute():
...



Working with unhashable arguments
---------------------------------
Expand Down
4 changes: 3 additions & 1 deletion src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class Params:
separate_files: bool = False
wait_for_calc_timeout: int = 0
allow_none: bool = False
cleanup_stale: bool = False
cleanup_interval: timedelta = timedelta(days=1)


_global_params = Params()
Expand Down Expand Up @@ -130,7 +132,7 @@ def set_global_params(**params: Any) -> None:
}
cachier.config._global_params = replace(
cachier.config._global_params,
**valid_params, # type: ignore[arg-type]
**valid_params,
)


Expand Down
28 changes: 27 additions & 1 deletion src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import inspect
import os
import threading
import warnings
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -120,6 +121,8 @@ def cachier(
separate_files: Optional[bool] = None,
wait_for_calc_timeout: Optional[int] = None,
allow_none: Optional[bool] = None,
cleanup_stale: Optional[bool] = None,
cleanup_interval: Optional[timedelta] = None,
):
"""Wrap as a persistent, stale-free memoization decorator.

Expand Down Expand Up @@ -183,6 +186,11 @@ def cachier(
allow_none: bool, optional
Allows storing None values in the cache. If False, functions returning
None will not be cached and are recalculated every call.
cleanup_stale: bool, optional
If True, stale cache entries are periodically deleted in a background
thread. Defaults to False.
cleanup_interval: datetime.timedelta, optional
Minimum time between automatic cleanup runs. Defaults to one day.

"""
# Check for deprecated parameters
Expand Down Expand Up @@ -236,6 +244,9 @@ def cachier(
def _cachier_decorator(func):
core.set_func(func)

last_cleanup = datetime.min
cleanup_lock = threading.Lock()

# ---
# MAINTAINER NOTE: max_age parameter
#
Expand All @@ -261,7 +272,7 @@ def _cachier_decorator(func):
# ---

def _call(*args, max_age: Optional[timedelta] = None, **kwds):
nonlocal allow_none
nonlocal allow_none, last_cleanup
_allow_none = _update_with_defaults(allow_none, "allow_none", kwds)
# print('Inside general wrapper for {}.'.format(func.__name__))
ignore_cache = _pop_kwds_with_deprecation(
Expand All @@ -280,11 +291,26 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
stale_after, "stale_after", kwds
)
_next_time = _update_with_defaults(next_time, "next_time", kwds)
_cleanup_flag = _update_with_defaults(
cleanup_stale, "cleanup_stale", kwds
)
_cleanup_interval_val = _update_with_defaults(
cleanup_interval, "cleanup_interval", kwds
)
# merge args expanded as kwargs and the original kwds
kwargs = _convert_args_kwargs(
func, _is_method=core.func_is_method, args=args, kwds=kwds
)

if _cleanup_flag:
now = datetime.now()
with cleanup_lock:
if now - last_cleanup >= _cleanup_interval_val:
last_cleanup = now
_get_executor().submit(
core.delete_stale_entries, _stale_after
)

_print = print if verbose else lambda x: None

# Check current global caching state dynamically
Expand Down
5 changes: 5 additions & 0 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import abc # for the _BaseCore abstract base class
import inspect
import threading
from datetime import timedelta
from typing import Callable, Optional, Tuple

from .._types import HashFunc
Expand Down Expand Up @@ -112,3 +113,7 @@ def clear_cache(self) -> None:
@abc.abstractmethod
def clear_being_calculated(self) -> None:
"""Mark all entries in this cache as not being calculated."""

@abc.abstractmethod
def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Delete cache entries older than ``stale_after``.

    Concrete backend cores must remove every cached entry whose stored
    timestamp is more than ``stale_after`` in the past.

    Parameters
    ----------
    stale_after : datetime.timedelta
        Maximum allowed age of a cache entry; older entries are deleted.

    """
12 changes: 11 additions & 1 deletion src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""A memory-based caching core for cachier."""

import threading
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple

from .._types import HashFunc
Expand Down Expand Up @@ -103,3 +103,13 @@ def clear_being_calculated(self) -> None:
for entry in self.cache.values():
entry._processing = False
entry._condition = None

def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Evict in-memory cache entries whose age exceeds ``stale_after``."""
    cutoff = datetime.now() - stale_after
    with self.lock:
        stale_keys = [
            key
            for key, entry in self.cache.items()
            if entry.time < cutoff
        ]
        for stale_key in stale_keys:
            del self.cache[stale_key]
9 changes: 8 additions & 1 deletion src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time # to sleep when waiting on Mongo cache\
import warnings # to warn if pymongo is missing
from contextlib import suppress
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Optional, Tuple

from .._types import HashFunc, Mongetter
Expand Down Expand Up @@ -146,3 +146,10 @@ def clear_being_calculated(self) -> None:
},
update={"$set": {"processing": False}},
)

def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Delete this function's stale documents from the MongoDB cache."""
    cutoff = datetime.now() - stale_after
    stale_filter = {"func": self._func_str, "time": {"$lt": cutoff}}
    self.mongo_collection.delete_many(filter=stale_filter)
50 changes: 38 additions & 12 deletions src/cachier/cores/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import os
import pickle # for local caching
import time
from datetime import datetime
from contextlib import suppress
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple, Union

import portalocker # to lock on pickle cache IO
Expand Down Expand Up @@ -260,16 +261,16 @@ def mark_entry_not_calculated(self, key: str) -> None:
cache[key]._processing = False
self._save_cache(cache)

def _create_observer(self) -> Observer:
def _create_observer(self) -> Observer: # type: ignore[valid-type]
"""Create a new observer instance."""
return Observer()

def _cleanup_observer(self, observer: Observer) -> None:
def _cleanup_observer(self, observer: Observer) -> None: # type: ignore[valid-type]
"""Clean up observer properly."""
try:
if observer.is_alive():
observer.stop()
observer.join(timeout=1.0)
if observer.is_alive(): # type: ignore[attr-defined]
observer.stop() # type: ignore[attr-defined]
observer.join(timeout=1.0) # type: ignore[attr-defined]
except Exception as e:
logging.debug("Observer cleanup failed: %s", e)

Expand All @@ -296,7 +297,7 @@ def wait_on_entry_calc(self, key: str) -> Any:
else:
raise

def _wait_with_inotify(self, key: str, filename: str) -> Any:
def _wait_with_inotify(self, key: str, filename: str) -> Any: # type: ignore[valid-type]
"""Wait for calculation using inotify with proper cleanup."""
event_handler = _PickleCore.CacheChangeHandler(
filename=filename, core=self, key=key
Expand All @@ -306,14 +307,14 @@ def _wait_with_inotify(self, key: str, filename: str) -> Any:
event_handler.inject_observer(observer)

try:
observer.schedule(
observer.schedule( # type: ignore[attr-defined]
event_handler, path=self.cache_dir, recursive=True
)
observer.start()
observer.start() # type: ignore[attr-defined]

time_spent = 0
while observer.is_alive():
observer.join(timeout=1.0)
while observer.is_alive(): # type: ignore[attr-defined]
observer.join(timeout=1.0) # type: ignore[attr-defined]
time_spent += 1
self.check_calc_timeout(time_spent)

Expand All @@ -324,7 +325,7 @@ def _wait_with_inotify(self, key: str, filename: str) -> Any:
return event_handler.value
finally:
# Always cleanup the observer
self._cleanup_observer(observer)
self._cleanup_observer(observer) # type: ignore[attr-defined]

def _wait_with_polling(self, key: str) -> Any:
"""Fallback method using polling instead of inotify."""
Expand Down Expand Up @@ -364,3 +365,28 @@ def clear_being_calculated(self) -> None:
for key in cache:
cache[key]._processing = False
self._save_cache(cache)

def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Drop pickle-cache entries that have aged past ``stale_after``."""
    cutoff = datetime.now() - stale_after
    if self.separate_files:
        # One file per entry: scan the cache directory for this
        # function's files and unlink each stale one individually.
        cache_dir, base_name = os.path.split(self.cache_fpath)
        prefix = f"{base_name}_"
        for fname in os.listdir(cache_dir):
            if not fname.startswith(prefix):
                continue
            entry = self._load_cache_by_key(hash_str=fname.split("_")[-1])
            if entry is not None and entry.time < cutoff:
                # Another process may have removed the file already.
                with suppress(FileNotFoundError):
                    os.remove(os.path.join(cache_dir, fname))
        return

    # Single-file layout: rewrite the cache dict without stale keys.
    with self.lock:
        cache = self.get_cache_dict(reload=True)
        stale_keys = [k for k, v in cache.items() if v.time < cutoff]
        for stale_key in stale_keys:
            del cache[stale_key]
        self._save_cache(cache)
27 changes: 26 additions & 1 deletion src/cachier/cores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pickle
import time
import warnings
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Callable, Optional, Tuple, Union

try:
Expand Down Expand Up @@ -223,3 +223,28 @@ def clear_being_calculated(self) -> None:
warnings.warn(
f"Redis clear_being_calculated failed: {e}", stacklevel=2
)

def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Remove stale entries from the Redis cache.

    Scans every key belonging to this function and deletes those whose
    stored ``timestamp`` hash field is older than ``stale_after``.
    Unparsable timestamps are warned about and skipped; any Redis-level
    failure is reported as a warning (best-effort cleanup, never raises).

    Parameters
    ----------
    stale_after : datetime.timedelta
        Entries with a timestamp older than this are deleted.

    """
    redis_client = self._resolve_redis_client()
    pattern = f"{self.key_prefix}:{self._func_str}:*"
    try:
        keys = redis_client.keys(pattern)
        threshold = datetime.now() - stale_after
        for key in keys:
            ts = redis_client.hget(key, "timestamp")
            if ts is None:
                continue
            try:
                # Clients configured with decode_responses=True return
                # str instead of bytes; handle both so cleanup is not
                # silently skipped for such clients.
                raw = ts.decode("utf-8") if isinstance(ts, bytes) else ts
                ts_val = datetime.fromisoformat(raw)
            except Exception as exc:
                warnings.warn(
                    f"Redis timestamp parse failed: {exc}", stacklevel=2
                )
                continue
            if ts_val < threshold:
                redis_client.delete(key)
    except Exception as e:
        warnings.warn(
            f"Redis delete_stale_entries failed: {e}", stacklevel=2
        )
16 changes: 15 additions & 1 deletion src/cachier/cores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pickle
import threading
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Callable, Optional, Tuple, Union

try:
Expand Down Expand Up @@ -286,3 +286,17 @@ def clear_being_calculated(self) -> None:
.values(processing=False)
)
session.commit()

def delete_stale_entries(self, stale_after: timedelta) -> None:
    """Purge this function's SQL cache rows older than ``stale_after``."""
    cutoff = datetime.now() - stale_after
    with self._lock, self._Session() as session:
        stale_filter = and_(
            CacheTable.function_id == self._func_str,
            CacheTable.timestamp < cutoff,
        )
        session.execute(delete(CacheTable).where(stale_filter))
        session.commit()
49 changes: 49 additions & 0 deletions tests/test_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import pickle
import time
from dataclasses import replace
from datetime import timedelta

import pytest

import cachier
from cachier import cachier as cachier_dec

_copied_defaults = replace(cachier.get_global_params())


def setup_function() -> None:
    # Restore the library-wide defaults captured at import time so each
    # test starts from a clean global configuration.
    cachier.set_global_params(**vars(_copied_defaults))


def teardown_function() -> None:
    # Undo any global-parameter changes the test made so later tests are
    # not affected.
    cachier.set_global_params(**vars(_copied_defaults))


@pytest.mark.pickle
def test_cleanup_stale_entries(tmp_path):
    """Stale pickle-cache entries are removed by the automatic cleanup."""

    @cachier_dec(
        cache_dir=tmp_path,
        stale_after=timedelta(seconds=1),
        cleanup_stale=True,
        cleanup_interval=timedelta(seconds=0),
    )
    def add(x):
        return x + 1

    add.clear_cache()
    add(1)
    add(2)
    fname = f".{add.__module__}.{add.__qualname__}".replace("<", "_").replace(
        ">", "_"
    )
    cache_path = os.path.join(add.cache_dpath(), fname)
    with open(cache_path, "rb") as fh:
        data = pickle.load(fh)
    assert len(data) == 2
    time.sleep(1.1)  # let both entries go stale
    add(1)  # triggers the background cleanup and re-caches add(1)
    # The cleanup runs on a background executor, so poll with a deadline
    # instead of a single fixed sleep (which is flaky on slow CI hosts).
    deadline = time.time() + 5
    while time.time() < deadline:
        with open(cache_path, "rb") as fh:
            data = pickle.load(fh)
        if len(data) == 1:
            break
        time.sleep(0.1)
    assert len(data) == 1
Loading
Loading