Skip to content

Commit 970e848

Browse files
feat(cache): implement get_or_set_cache with stampede protection and stale-while-revalidate
- Added get_or_set_cache function to manage caching with background refresh. - Increased cache TTL to 24 hours and added refresh_after setting. - Implemented per-key locks to prevent cache stampede. - Updated cache management in libraries, specs, stats, and plots endpoints.
1 parent c5f48d6 commit 970e848

File tree

10 files changed

+444
-107
lines changed

10 files changed

+444
-107
lines changed

api/cache.py

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,44 @@
22
Caching utilities for pyplots API.
33
44
Centralized cache management with consistent key patterns.
5+
Includes stampede protection (per-key asyncio.Lock) and
6+
stale-while-revalidate (background refresh before TTL expiry).
57
"""
68

7-
from typing import Any
9+
import asyncio
10+
import logging
11+
import time
12+
from collections.abc import Awaitable, Callable
13+
from typing import Any, TypeVar
814

915
from cachetools import TTLCache
1016

1117
from core.config import settings
1218

1319

20+
T = TypeVar("T")
21+
22+
logger = logging.getLogger(__name__)
23+
1424
# Global cache instance (configured via settings)
1525
_cache: TTLCache = TTLCache(maxsize=settings.cache_maxsize, ttl=settings.cache_ttl)
1626

27+
# Per-key locks to prevent cache stampede (~20-30 unique keys, memory negligible)
28+
_locks: dict[str, asyncio.Lock] = {}
29+
30+
# Timestamps for stale-while-revalidate (key -> monotonic time of last set)
31+
_timestamps: dict[str, float] = {}
32+
33+
34+
def _get_lock(key: str) -> asyncio.Lock:
    """Return the lock associated with *key*, creating it on first use.

    No synchronization is needed around the dict itself: asyncio runs all
    callbacks on a single thread, so this check-then-insert cannot race.
    """
    lock = _locks.get(key)
    if lock is None:
        lock = asyncio.Lock()
        _locks[key] = lock
    return lock
42+
1743

1844
def cache_key(*parts: str) -> str:
1945
"""
@@ -54,6 +80,13 @@ def set_cache(key: str, value: Any) -> None:
5480
value: Value to cache.
5581
"""
5682
_cache[key] = value
83+
_timestamps[key] = time.monotonic()
84+
85+
86+
def cache_age(key: str) -> float | None:
    """Return seconds elapsed since *key* was last written via set_cache.

    Returns None when no timestamp has been recorded for the key.
    """
    last_set = _timestamps.get(key)
    if last_set is None:
        return None
    return time.monotonic() - last_set
5790

5891

5992
def clear_cache() -> None:
@@ -67,6 +100,7 @@ def clear_cache() -> None:
67100
>>> clear_cache() # Invalidates all cached responses
68101
"""
69102
_cache.clear()
103+
_timestamps.clear()
70104

71105

72106
def clear_cache_by_pattern(pattern: str) -> int:
@@ -88,6 +122,7 @@ def clear_cache_by_pattern(pattern: str) -> int:
88122
keys_to_delete = [key for key in _cache.keys() if pattern in key]
89123
for key in keys_to_delete:
90124
del _cache[key]
125+
_timestamps.pop(key, None)
91126
return len(keys_to_delete)
92127

93128

@@ -152,3 +187,71 @@ def get_cache_stats() -> dict:
152187
{"size": 42, "maxsize": 1000, "ttl": 600}
153188
"""
154189
return {"size": len(_cache), "maxsize": _cache.maxsize, "ttl": _cache.ttl}
190+
191+
192+
# ---------------------------------------------------------------------------
193+
# Stampede protection + stale-while-revalidate
194+
# ---------------------------------------------------------------------------
195+
196+
197+
async def get_or_set_cache(
    key: str,
    factory: Callable[[], Awaitable[T]],
    *,
    refresh_after: float | None = None,
    refresh_factory: Callable[[], Awaitable[T]] | None = None,
) -> T:
    """Return the cached value for *key*, computing it via *factory* on a miss.

    A per-key asyncio.Lock guarantees that concurrent cold misses run the
    factory only once (stampede protection). When *refresh_after* is set and
    the cached entry is older than that many seconds, the stale value is
    returned immediately and a background refresh is kicked off
    (stale-while-revalidate).

    Args:
        key: Cache key.
        factory: Async callable producing the value (e.g. a DB query).
            Runs inline on a cold miss, so it may capture a request-scoped
            DB session.
        refresh_after: Age in seconds beyond which a background refresh is
            scheduled. None disables stale-while-revalidate.
        refresh_factory: Self-contained async callable used for background
            refreshes; it must open its own DB session (via get_db_context).
            Only consulted when refresh_after is set. Falls back to *factory*
            when omitted — NOTE(review): a factory holding a request-scoped
            session may not be safe to run after the request ends; confirm
            callers relying on the fallback.
    """
    hit = get_cache(key)
    if hit is not None:
        # Serve the (possibly stale) hit right away; refresh out of band.
        if refresh_after is not None:
            elapsed = cache_age(key)
            if elapsed is not None and elapsed > refresh_after:
                _schedule_refresh(key, refresh_factory or factory)
        return hit

    # Cold miss: compute under the per-key lock so only one caller does the work.
    async with _get_lock(key):
        # Another waiter may have filled the cache while we awaited the lock.
        hit = get_cache(key)
        if hit is not None:
            return hit
        value = await factory()
        set_cache(key, value)
        return value
237+
238+
239+
def _schedule_refresh(key: str, factory: Callable[[], Awaitable[Any]]) -> None:
    """Kick off a background refresh for *key* unless one is already running.

    The refresh lock lives under a derived key so it never contends with the
    cold-miss lock on the data key itself.

    NOTE(review): lock.locked() only turns True once the scheduled task has
    started and acquired the lock, so back-to-back calls may each schedule a
    task before the first one runs.
    """
    guard = _get_lock(f"_refresh:{key}")
    if not guard.locked():
        asyncio.create_task(_background_refresh(key, factory, guard))
246+
247+
248+
async def _background_refresh(
    key: str, factory: Callable[[], Awaitable[Any]], lock: asyncio.Lock
) -> None:
    """Run *factory* and store its result under *key* in the cache.

    Intended to run as a fire-and-forget task created by _schedule_refresh.
    Failures are logged and swallowed so the stale cached value keeps being
    served rather than surfacing an error to any caller.

    Args:
        key: Cache key to refresh.
        factory: Self-contained async callable producing the new value.
        lock: Refresh lock for this key; held for the duration of the run.
    """
    # _schedule_refresh checks lock.locked() before create_task, but the lock
    # is only acquired once this coroutine actually starts, so duplicate
    # tasks can be scheduled in that window. Re-check here and drop the
    # duplicate: there is no await point between this check and the
    # uncontended acquire below, so the check-then-acquire is atomic under
    # asyncio's single-threaded scheduling.
    if lock.locked():
        return
    async with lock:
        try:
            result = await factory()
            set_cache(key, result)
        except Exception:
            logger.warning("Background cache refresh failed for key: %s", key, exc_info=True)

api/routers/libraries.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,38 @@
33
from fastapi import APIRouter, Depends
44
from sqlalchemy.ext.asyncio import AsyncSession
55

6-
from api.cache import cache_key, get_cache, set_cache
6+
from api.cache import cache_key, get_cache, get_or_set_cache, set_cache
77
from api.dependencies import optional_db, require_db
88
from api.exceptions import raise_not_found
9+
from core.config import settings
910
from core.constants import LIBRARIES_METADATA, SUPPORTED_LIBRARIES
1011
from core.database import LibraryRepository, SpecRepository
12+
from core.database.connection import get_db_context
1113
from core.utils import strip_noqa_comments
1214

1315

1416
router = APIRouter(tags=["libraries"])
1517

1618

19+
async def _refresh_libraries() -> dict:
    """Rebuild the /libraries payload using a fresh, self-owned DB session.

    Background-refresh factory: it cannot reuse a request-scoped session,
    so it opens its own via get_db_context().
    """
    async with get_db_context() as db:
        rows = await LibraryRepository(db).get_all()
        # Build the payload while the session is still open, since ORM
        # attribute access may depend on it.
        return {
            "libraries": [
                {
                    "id": lib.id,
                    "name": lib.name,
                    "version": lib.version,
                    "documentation_url": lib.documentation_url,
                    "description": lib.description,
                }
                for lib in rows
            ]
        }
36+
37+
1738
@router.get("/libraries")
1839
async def get_libraries(db: AsyncSession | None = Depends(optional_db)):
1940
"""
@@ -24,28 +45,28 @@ async def get_libraries(db: AsyncSession | None = Depends(optional_db)):
2445
if db is None:
2546
return {"libraries": LIBRARIES_METADATA}
2647

27-
key = cache_key("libraries")
28-
cached = get_cache(key)
29-
if cached:
30-
return cached
31-
32-
repo = LibraryRepository(db)
33-
libraries = await repo.get_all()
34-
35-
result = {
36-
"libraries": [
37-
{
38-
"id": lib.id,
39-
"name": lib.name,
40-
"version": lib.version,
41-
"documentation_url": lib.documentation_url,
42-
"description": lib.description,
43-
}
44-
for lib in libraries
45-
]
46-
}
47-
set_cache(key, result)
48-
return result
48+
async def _fetch() -> dict:
49+
repo = LibraryRepository(db)
50+
libraries = await repo.get_all()
51+
return {
52+
"libraries": [
53+
{
54+
"id": lib.id,
55+
"name": lib.name,
56+
"version": lib.version,
57+
"documentation_url": lib.documentation_url,
58+
"description": lib.description,
59+
}
60+
for lib in libraries
61+
]
62+
}
63+
64+
return await get_or_set_cache(
65+
cache_key("libraries"),
66+
_fetch,
67+
refresh_after=settings.cache_refresh_after,
68+
refresh_factory=_refresh_libraries,
69+
)
4970

5071

5172
@router.get("/libraries/{library_id}/images")

api/routers/plots.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlalchemy.exc import SQLAlchemyError
88
from sqlalchemy.ext.asyncio import AsyncSession
99

10-
from api.cache import get_cache, set_cache
10+
from api.cache import get_or_set_cache
1111
from api.dependencies import require_db
1212
from api.exceptions import DatabaseQueryError
1313
from api.schemas import FilteredPlotsResponse
@@ -405,43 +405,30 @@ async def get_filtered_plots(
405405
"""
406406
# Parse query parameters
407407
filter_groups = _parse_filter_groups(request)
408+
cache_k = _build_cache_key(filter_groups)
408409

409-
# Check cache (cache stores unpaginated result; pagination applied after)
410-
cache_key = _build_cache_key(filter_groups)
411-
cached: FilteredPlotsResponse | None = None
412-
try:
413-
cached = get_cache(cache_key)
414-
except Exception as e:
415-
logger.warning("Cache read failed for key %s: %s", cache_key, e)
416-
417-
if cached is None:
418-
# Fetch data from database
410+
async def _fetch_filtered() -> FilteredPlotsResponse:
419411
try:
420412
repo = SpecRepository(db)
421413
all_specs = await repo.get_all()
422414
except SQLAlchemyError as e:
423415
logger.error("Database query failed in get_filtered_plots: %s", e)
424416
raise DatabaseQueryError("fetch_specs", str(e)) from e
425417

426-
# Build data structures
427418
spec_lookup = _build_spec_lookup(all_specs)
428419
impl_lookup = _build_impl_lookup(all_specs)
429420
all_images = _collect_all_images(all_specs)
430421
spec_id_to_tags = {spec_id: spec_data["tags"] for spec_id, spec_data in spec_lookup.items()}
431422

432-
# Filter images
433423
filtered_images = _filter_images(all_images, filter_groups, spec_lookup, impl_lookup)
434424

435-
# Calculate counts (always from ALL filtered images, not paginated)
436425
global_counts = _calculate_global_counts(all_specs)
437426
counts = _calculate_contextual_counts(filtered_images, spec_id_to_tags, impl_lookup)
438427
or_counts = _calculate_or_counts(filter_groups, all_images, spec_id_to_tags, spec_lookup, impl_lookup)
439428

440-
# Build spec_id -> title mapping for search/tooltips
441429
spec_titles = {spec_id: data["spec"].title for spec_id, data in spec_lookup.items() if data["spec"].title}
442430

443-
# Cache the full (unpaginated) result
444-
cached = FilteredPlotsResponse(
431+
return FilteredPlotsResponse(
445432
total=len(filtered_images),
446433
images=filtered_images,
447434
counts=counts,
@@ -450,10 +437,8 @@ async def get_filtered_plots(
450437
specTitles=spec_titles,
451438
)
452439

453-
try:
454-
set_cache(cache_key, cached)
455-
except Exception as e:
456-
logger.warning("Cache write failed for key %s: %s", cache_key, e)
440+
# get_or_set_cache provides stampede lock (no refresh_after — too many filter key variants)
441+
cached = await get_or_set_cache(cache_k, _fetch_filtered)
457442

458443
# Apply pagination on top of (possibly cached) result
459444
paginated = cached.images[offset : offset + limit] if limit else cached.images[offset:]

api/routers/specs.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,33 @@
33
from fastapi import APIRouter, Depends
44
from sqlalchemy.ext.asyncio import AsyncSession
55

6-
from api.cache import cache_key, get_cache, set_cache
6+
from api.cache import cache_key, get_cache, get_or_set_cache, set_cache
77
from api.dependencies import require_db
88
from api.exceptions import raise_not_found
99
from api.schemas import ImplementationResponse, SpecDetailResponse, SpecListItem
10+
from core.config import settings
1011
from core.database import SpecRepository
12+
from core.database.connection import get_db_context
1113
from core.utils import strip_noqa_comments
1214

1315

1416
router = APIRouter(tags=["specs"])
1517

1618

19+
async def _refresh_specs_list() -> list[SpecListItem]:
    """Rebuild the /specs list payload using a fresh, self-owned DB session.

    Background-refresh factory: opens its own session via get_db_context()
    instead of relying on a request-scoped one. Specs without at least one
    implementation are excluded, matching the /specs endpoint.
    """
    async with get_db_context() as db:
        specs = await SpecRepository(db).get_all()
        items: list[SpecListItem] = []
        for spec in specs:
            if not spec.impls:
                # Skip specs that have no implementations.
                continue
            items.append(
                SpecListItem(
                    id=spec.id,
                    title=spec.title,
                    description=spec.description,
                    tags=spec.tags,
                    library_count=len(spec.impls),
                )
            )
        return items
31+
32+
1733
@router.get("/specs", response_model=list[SpecListItem])
1834
async def get_specs(db: AsyncSession = Depends(require_db)):
1935
"""
@@ -22,24 +38,23 @@ async def get_specs(db: AsyncSession = Depends(require_db)):
2238
Returns only specs that have at least one implementation.
2339
"""
2440

25-
key = cache_key("specs_list")
26-
cached = get_cache(key)
27-
if cached:
28-
return cached
29-
30-
repo = SpecRepository(db)
31-
specs = await repo.get_all()
32-
33-
# Only return specs with at least one implementation
34-
result = [
35-
SpecListItem(
36-
id=spec.id, title=spec.title, description=spec.description, tags=spec.tags, library_count=len(spec.impls)
37-
)
38-
for spec in specs
39-
if spec.impls # Filter: only specs with implementations
40-
]
41-
set_cache(key, result)
42-
return result
41+
async def _fetch() -> list[SpecListItem]:
42+
repo = SpecRepository(db)
43+
specs = await repo.get_all()
44+
return [
45+
SpecListItem(
46+
id=spec.id, title=spec.title, description=spec.description, tags=spec.tags, library_count=len(spec.impls)
47+
)
48+
for spec in specs
49+
if spec.impls
50+
]
51+
52+
return await get_or_set_cache(
53+
cache_key("specs_list"),
54+
_fetch,
55+
refresh_after=settings.cache_refresh_after,
56+
refresh_factory=_refresh_specs_list,
57+
)
4358

4459

4560
@router.get("/specs/{spec_id}", response_model=SpecDetailResponse)

0 commit comments

Comments
 (0)