Skip to content

Commit eff402b

Browse files
committed
Organize test directories by cache backend; Add async pickle and memory tests
1 parent 4ec6ab9 commit eff402b

22 files changed

Lines changed: 280 additions & 90 deletions

src/cachier/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,23 +49,23 @@ def _get_executor(reset=False):
4949
return _get_executor.executor
5050

5151

52-
def _function_thread(core, key, func, args, kwds):
52+
def _function_thread(core: _BaseCore, key, func, args, kwds):
5353
try:
5454
func_res = func(*args, **kwds)
5555
core.set_entry(key, func_res)
5656
except BaseException as exc:
5757
print(f"Function call failed with the following exception:\n{exc}")
5858

5959

60-
async def _function_thread_async(core, key, func, args, kwds):
60+
async def _function_thread_async(core: _BaseCore, key, func, args, kwds):
6161
try:
6262
func_res = await func(*args, **kwds)
6363
await core.aset_entry(key, func_res)
6464
except BaseException as exc:
6565
print(f"Function call failed with the following exception:\n{exc}")
6666

6767

68-
def _calc_entry(core, key, func, args, kwds, printer=lambda *_: None) -> Optional[Any]:
68+
def _calc_entry(core: _BaseCore, key, func, args, kwds, printer=lambda *_: None) -> Optional[Any]:
6969
core.mark_entry_being_calculated(key)
7070
try:
7171
func_res = func(*args, **kwds)
@@ -77,7 +77,7 @@ def _calc_entry(core, key, func, args, kwds, printer=lambda *_: None) -> Optiona
7777
core.mark_entry_not_calculated(key)
7878

7979

80-
async def _calc_entry_async(core, key, func, args, kwds, printer=lambda *_: None) -> Optional[Any]:
80+
async def _calc_entry_async(core: _BaseCore, key, func, args, kwds, printer=lambda *_: None) -> Optional[Any]:
8181
await core.amark_entry_being_calculated(key)
8282
try:
8383
func_res = await func(*args, **kwds)

src/cachier/cores/memory.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ def get_entry_by_key(self, key: str, reload=False) -> Tuple[str, Optional[CacheE
2828
with self.lock:
2929
return key, self.cache.get(self._hash_func_key(key), None)
3030

31+
async def aget_entry(self, args: tuple[Any, ...], kwds: dict[str, Any]) -> Tuple[str, Optional[CacheEntry]]:
32+
key = self.get_key(args, kwds)
33+
return await self.aget_entry_by_key(key)
34+
35+
async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
36+
"""Get an entry by key."""
37+
return self.get_entry_by_key(key)
38+
3139
def set_entry(self, key: str, func_res: Any) -> bool:
3240
if not self._should_store(func_res):
3341
return False
@@ -50,6 +58,10 @@ def set_entry(self, key: str, func_res: Any) -> bool:
5058
)
5159
return True
5260

61+
async def aset_entry(self, key: str, func_res: Any) -> bool:
62+
"""Set an entry."""
63+
return self.set_entry(key, func_res)
64+
5365
def mark_entry_being_calculated(self, key: str) -> None:
5466
with self.lock:
5567
condition = threading.Condition()
@@ -67,6 +79,9 @@ def mark_entry_being_calculated(self, key: str) -> None:
6779
_condition=condition,
6880
)
6981

82+
async def amark_entry_being_calculated(self, key: str) -> None:
83+
self.mark_entry_being_calculated(key)
84+
7085
def mark_entry_not_calculated(self, key: str) -> None:
7186
hash_key = self._hash_func_key(key)
7287
with self.lock:
@@ -81,6 +96,9 @@ def mark_entry_not_calculated(self, key: str) -> None:
8196
cond.release()
8297
entry._condition = None
8398

99+
async def amark_entry_not_calculated(self, key: str) -> None:
100+
self.mark_entry_not_calculated(key)
101+
84102
def wait_on_entry_calc(self, key: str) -> Any:
85103
hash_key = self._hash_func_key(key)
86104
with self.lock: # pragma: no cover

src/cachier/cores/pickle.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,13 @@ def get_entry_by_key(self, key: str, reload: bool = False) -> Tuple[str, Optiona
182182
return key, self._load_cache_by_key(key)
183183
return key, self.get_cache_dict(reload).get(key)
184184

185+
async def aget_entry(self, args: tuple[Any, ...], kwds: dict[str, Any]) -> Tuple[str, Optional[CacheEntry]]:
186+
key = self.get_key(args, kwds)
187+
return await self.aget_entry_by_key(key)
188+
189+
async def aget_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
190+
return self.get_entry_by_key(key)
191+
185192
def set_entry(self, key: str, func_res: Any) -> bool:
186193
if not self._should_store(func_res):
187194
return False
@@ -202,6 +209,9 @@ def set_entry(self, key: str, func_res: Any) -> bool:
202209
self._save_cache(cache)
203210
return True
204211

212+
async def aset_entry(self, key: str, func_res: Any) -> bool:
213+
return self.set_entry(key, func_res)
214+
205215
def mark_entry_being_calculated_separate_files(self, key: str) -> None:
206216
self._save_cache(
207217
CacheEntry(value=None, time=datetime.now(), stale=False, _processing=True),
@@ -228,6 +238,9 @@ def mark_entry_being_calculated(self, key: str) -> None:
228238
cache[key] = CacheEntry(value=None, time=datetime.now(), stale=False, _processing=True)
229239
self._save_cache(cache)
230240

241+
async def amark_entry_being_calculated(self, key: str) -> None:
242+
self.mark_entry_being_calculated(key)
243+
231244
def mark_entry_not_calculated(self, key: str) -> None:
232245
if self.separate_files:
233246
self._mark_entry_not_calculated_separate_files(key)
@@ -238,6 +251,9 @@ def mark_entry_not_calculated(self, key: str) -> None:
238251
cache[key]._processing = False
239252
self._save_cache(cache)
240253

254+
async def amark_entry_not_calculated(self, key: str) -> None:
255+
self.mark_entry_not_calculated(key)
256+
241257
def _create_observer(self) -> Observer: # type: ignore[valid-type]
242258
"""Create a new observer instance."""
243259
return Observer()

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def cleanup_mongo_clients():
1717

1818
# Cleanup after all tests
1919
try:
20-
from tests.test_mongo_core import _test_mongetter
20+
from tests.mongo_tests.test_mongo_core import _test_mongetter
2121
except ImportError:
2222
return
2323

tests/memory_tests/__init__.py

Whitespace-only changes.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Tests for async support in the memory core."""
2+
3+
import asyncio
4+
from datetime import timedelta
5+
6+
import pytest
7+
8+
from cachier import cachier
9+
10+
11+
@pytest.mark.memory
12+
@pytest.mark.asyncio
13+
async def test_async_memory_basic_caching():
14+
"""Ensure async functions are cached by the memory backend."""
15+
call_count = 0
16+
17+
@cachier(backend="memory")
18+
async def async_memory_cached(x: int) -> int:
19+
nonlocal call_count
20+
call_count += 1
21+
await asyncio.sleep(0.01)
22+
return x + call_count
23+
24+
async_memory_cached.clear_cache()
25+
try:
26+
value1 = await async_memory_cached(3)
27+
value2 = await async_memory_cached(3)
28+
assert value1 == value2 == 4
29+
assert call_count == 1
30+
finally:
31+
async_memory_cached.clear_cache()
32+
33+
34+
@pytest.mark.memory
35+
@pytest.mark.asyncio
36+
async def test_async_memory_next_time_returns_stale_then_updates():
37+
"""Ensure next_time returns stale value and updates asynchronously."""
38+
call_count = 0
39+
40+
@cachier(
41+
backend="memory",
42+
stale_after=timedelta(milliseconds=150),
43+
next_time=True,
44+
)
45+
async def async_memory_next_time(_: int) -> int:
46+
nonlocal call_count
47+
call_count += 1
48+
await asyncio.sleep(0.05)
49+
return call_count
50+
51+
async_memory_next_time.clear_cache()
52+
try:
53+
first = await async_memory_next_time(1)
54+
assert first == 1
55+
56+
await asyncio.sleep(0.2)
57+
58+
stale = await async_memory_next_time(1)
59+
assert stale == 1
60+
61+
updated = stale
62+
for _ in range(10):
63+
await asyncio.sleep(0.05)
64+
updated = await async_memory_next_time(1)
65+
if updated > 1:
66+
break
67+
68+
assert updated > 1
69+
assert call_count >= 2
70+
await asyncio.sleep(0.1)
71+
finally:
72+
async_memory_next_time.clear_cache()

tests/mongo_tests/__init__.py

Whitespace-only changes.

tests/test_async_backend_clients.py renamed to tests/mongo_tests/test_async_mongo_core.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Tests for async client factories (redis_client / mongetter) with async functions."""
1+
"""Tests for async Mongo client with async functions."""
22

33
import asyncio
44
from contextlib import suppress
@@ -9,80 +9,6 @@
99
from cachier import cachier
1010

1111

12-
class _AsyncInMemoryRedis:
13-
"""Minimal async Redis-like client implementing required hash operations."""
14-
15-
def __init__(self):
16-
self._data: dict[str, dict[str, object]] = {}
17-
18-
async def hgetall(self, key: str) -> dict[bytes, object]:
19-
raw = self._data.get(key, {})
20-
res: dict[bytes, object] = {}
21-
for k, v in raw.items():
22-
res[k.encode("utf-8")] = v.encode("utf-8") if isinstance(v, str) else v
23-
return res
24-
25-
async def hset(self, key: str, field=None, value=None, mapping=None, **kwargs):
26-
if key not in self._data:
27-
self._data[key] = {}
28-
29-
if mapping is not None:
30-
self._data[key].update(mapping)
31-
return
32-
if field is not None and value is not None:
33-
self._data[key][field] = value
34-
return
35-
if kwargs:
36-
self._data[key].update(kwargs)
37-
38-
39-
@pytest.mark.redis
40-
@pytest.mark.asyncio
41-
async def test_async_redis_client_factory():
42-
pytest.importorskip("redis")
43-
44-
client = _AsyncInMemoryRedis()
45-
46-
async def get_redis_client():
47-
return client
48-
49-
@cachier(backend="redis", redis_client=get_redis_client)
50-
async def async_cached_redis(x: int) -> float:
51-
await asyncio.sleep(0.01)
52-
return random() + x
53-
54-
val1 = await async_cached_redis(3)
55-
val2 = await async_cached_redis(3)
56-
assert val1 == val2
57-
58-
59-
@pytest.mark.redis
60-
@pytest.mark.asyncio
61-
async def test_async_redis_client_factory_method_args_and_kwargs():
62-
pytest.importorskip("redis")
63-
64-
client = _AsyncInMemoryRedis()
65-
66-
async def get_redis_client():
67-
return client
68-
69-
call_count = 0
70-
71-
class _RedisMethods:
72-
@cachier(backend="redis", redis_client=get_redis_client)
73-
async def async_cached_redis_method_args_kwargs(self, x: int, y: int) -> int:
74-
nonlocal call_count
75-
call_count += 1
76-
await asyncio.sleep(0.01)
77-
return call_count
78-
79-
obj = _RedisMethods()
80-
val1 = await obj.async_cached_redis_method_args_kwargs(1, 2)
81-
val2 = await obj.async_cached_redis_method_args_kwargs(y=2, x=1)
82-
assert val1 == val2 == 1
83-
assert call_count == 1
84-
85-
8612
@pytest.mark.mongo
8713
@pytest.mark.filterwarnings("ignore:Python 3\\.14 will, by default, filter extracted tar archives.*:DeprecationWarning")
8814
@pytest.mark.asyncio

0 commit comments

Comments
 (0)