Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]


@dataclass
@dataclass(frozen=True, slots=True)
class RangeByteRequest:
"""Request a specific byte range"""

Expand All @@ -29,15 +29,15 @@ class RangeByteRequest:
"""The end of the byte range request (exclusive)."""


@dataclass
@dataclass(frozen=True, slots=True)
class OffsetByteRequest:
"""Request all bytes starting from a given byte offset"""

offset: int
"""The byte offset for the offset range request."""


@dataclass
@dataclass(frozen=True, slots=True)
class SuffixByteRequest:
"""Request up to the last `n` bytes"""

Expand Down
245 changes: 151 additions & 94 deletions src/zarr/experimental/cache_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
if TYPE_CHECKING:
from zarr.core.buffer.core import Buffer, BufferPrototype

# A cache entry identifier. Plain ``str`` for full-key entries that live in
# the Store-backed cache; ``(str, ByteRequest)`` for byte-range entries that
# live in the in-memory range cache. Both kinds are tracked under this one
# key type so they can share a single size budget and LRU order.
_CacheEntryKey = str | tuple[str, ByteRequest]


@dataclass(slots=True)
class _CacheState:
cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
cache_order: OrderedDict[_CacheEntryKey, None] = field(default_factory=OrderedDict)
current_size: int = 0
key_sizes: dict[str, int] = field(default_factory=dict)
key_sizes: dict[_CacheEntryKey, int] = field(default_factory=dict)
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
hits: int = 0
misses: int = 0
evictions: int = 0
key_insert_times: dict[str, float] = field(default_factory=dict)
key_insert_times: dict[_CacheEntryKey, float] = field(default_factory=dict)
range_cache: dict[str, dict[ByteRequest, Buffer]] = field(default_factory=dict)


class CacheStore(WrapperStore[Store]):
Expand All @@ -36,6 +42,11 @@ class CacheStore(WrapperStore[Store]):
as the cache backend. This provides persistent caching capabilities with
time-based expiration, size-based eviction, and flexible cache storage options.

Full-key reads are cached in the Store-backed cache. Byte-range reads are
cached in a separate in-memory dictionary so that partial reads never
pollute the filesystem (or other persistent backend). Both caches share
the same ``max_size`` budget and LRU eviction policy.

Parameters
----------
store : Store
Expand Down Expand Up @@ -129,21 +140,21 @@ def with_read_only(self, read_only: bool = False) -> Self:
store._state = self._state
return store

def _is_key_fresh(self, key: str) -> bool:
"""Check if a cached key is still fresh based on max_age_seconds.
def _is_key_fresh(self, entry_key: _CacheEntryKey) -> bool:
"""Check if a cached entry is still fresh based on max_age_seconds.

Uses monotonic time for accurate elapsed time measurement.
"""
if self.max_age_seconds == "infinity":
return True
now = time.monotonic()
elapsed = now - self._state.key_insert_times.get(key, 0)
elapsed = now - self._state.key_insert_times.get(entry_key, 0)
return elapsed < self.max_age_seconds

async def _accommodate_value(self, value_size: int) -> None:
"""Ensure there is enough space in the cache for a new value.

Must be called while holding self._lock.
Must be called while holding self._state.lock.
"""
if self.max_size is None:
return
Expand All @@ -154,122 +165,168 @@ async def _accommodate_value(self, value_size: int) -> None:
lru_key = next(iter(self._state.cache_order))
await self._evict_key(lru_key)

async def _evict_key(self, key: str) -> None:
"""Evict a key from the cache.

Must be called while holding self._lock.
Updates size tracking atomically with deletion.
"""
try:
key_size = self._state.key_sizes.get(key, 0)

# Delete from cache store
await self._cache.delete(key)
async def _evict_key(self, entry_key: _CacheEntryKey) -> None:
"""Evict a cache entry.

# Update tracking after successful deletion
self._remove_from_tracking(key)
self._state.current_size = max(0, self._state.current_size - key_size)
self._state.evictions += 1
Must be called while holding self._state.lock.

logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
except Exception:
logger.exception("_evict_key: failed to evict key %s", key)
raise # Re-raise to signal eviction failure
For ``str`` keys the entry is deleted from the Store-backed cache.
For ``(str, ByteRequest)`` keys the entry is removed from the
in-memory range cache.
"""
key_size = self._state.key_sizes.get(entry_key, 0)

async def _cache_value(self, key: str, value: Buffer) -> None:
"""Cache a value with size tracking.
if isinstance(entry_key, str):
await self._cache.delete(entry_key)
else:
base_key, byte_range = entry_key
per_key = self._state.range_cache.get(base_key)
if per_key is not None:
per_key.pop(byte_range, None)
if not per_key:
del self._state.range_cache[base_key]

self._state.cache_order.pop(entry_key, None)
self._state.key_insert_times.pop(entry_key, None)
self._state.key_sizes.pop(entry_key, None)
self._state.current_size = max(0, self._state.current_size - key_size)
self._state.evictions += 1

async def _track_entry(self, entry_key: _CacheEntryKey, value: Buffer) -> bool:
"""Register *entry_key* in the shared size / LRU tracking.

Returns ``True`` if the entry was tracked, ``False`` if the value
exceeds ``max_size`` and was skipped. Callers should roll back any
data they already stored when this returns ``False``.

This method holds the lock for the entire operation to ensure atomicity.
"""
value_size = len(value)

# Check if value exceeds max size
if self.max_size is not None and value_size > self.max_size:
logger.warning(
"_cache_value: value size %d exceeds max_size %d, skipping cache",
value_size,
self.max_size,
)
return
return False

async with self._state.lock:
# If key already exists, subtract old size first
if key in self._state.key_sizes:
old_size = self._state.key_sizes[key]
if entry_key in self._state.key_sizes:
old_size = self._state.key_sizes[entry_key]
self._state.current_size -= old_size
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)

# Make room for the new value (this calls _evict_key_locked internally)
# Make room for the new value
await self._accommodate_value(value_size)

# Update tracking atomically
self._state.cache_order[key] = None # OrderedDict to track access order
self._state.cache_order[entry_key] = None
self._state.current_size += value_size
self._state.key_sizes[key] = value_size
self._state.key_insert_times[key] = time.monotonic()
self._state.key_sizes[entry_key] = value_size
self._state.key_insert_times[entry_key] = time.monotonic()

logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
return True

async def _update_access_order(self, key: str) -> None:
async def _update_access_order(self, entry_key: _CacheEntryKey) -> None:
"""Update the access order for LRU tracking."""
if key in self._state.cache_order:
if entry_key in self._state.cache_order:
async with self._state.lock:
# Move to end (most recently used)
self._state.cache_order.move_to_end(key)
self._state.cache_order.move_to_end(entry_key)

def _remove_from_tracking(self, key: str) -> None:
"""Remove a key from all tracking structures.
def _remove_from_tracking(self, entry_key: _CacheEntryKey) -> None:
"""Remove an entry from all tracking structures.

Must be called while holding self._state.lock.
"""
self._state.cache_order.pop(key, None)
self._state.key_insert_times.pop(key, None)
self._state.key_sizes.pop(key, None)
self._state.cache_order.pop(entry_key, None)
self._state.key_insert_times.pop(entry_key, None)
self._state.key_sizes.pop(entry_key, None)

def _invalidate_range_entries(self, key: str) -> None:
"""Remove all byte-range entries for *key* from the range cache and tracking.

Must be called while holding self._state.lock.
"""
per_key = self._state.range_cache.pop(key, None)
if per_key is not None:
for byte_range in per_key:
entry_key: _CacheEntryKey = (key, byte_range)
entry_size = self._state.key_sizes.pop(entry_key, 0)
self._state.cache_order.pop(entry_key, None)
self._state.key_insert_times.pop(entry_key, None)
self._state.current_size = max(0, self._state.current_size - entry_size)

# ------------------------------------------------------------------
# get helpers
# ------------------------------------------------------------------

async def _cache_miss(
self, key: str, byte_range: ByteRequest | None, result: Buffer | None
) -> None:
"""Handle a cache miss by storing or cleaning up after a source-store fetch."""
if result is None:
if byte_range is None:
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
else:
entry_key: _CacheEntryKey = (key, byte_range)
async with self._state.lock:
per_key = self._state.range_cache.get(key)
if per_key is not None:
per_key.pop(byte_range, None)
if not per_key:
del self._state.range_cache[key]
self._remove_from_tracking(entry_key)
else:
if byte_range is None:
await self._cache.set(key, result)
await self._track_entry(key, result)
else:
entry_key = (key, byte_range)
self._state.range_cache.setdefault(key, {})[byte_range] = result
tracked = await self._track_entry(entry_key, result)
if not tracked:
# Value too large for the cache — roll back the insertion
per_key = self._state.range_cache.get(key)
if per_key is not None:
per_key.pop(byte_range, None)
if not per_key:
del self._state.range_cache[key]

async def _get_try_cache(
    self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
    """Serve a read from the appropriate cache, falling back to the source store.

    Full-key reads consult the Store-backed cache; byte-range reads
    consult the in-memory range cache. A hit bumps the hit counter and
    LRU position; a miss fetches from the source store and hands the
    outcome to ``_cache_miss`` for caching / cleanup.
    """
    if byte_range is None:
        # Full-key read — use Store-backed cache
        cached = await self._cache.get(key, prototype)
        if cached is not None:
            self._state.hits += 1
            await self._update_access_order(key)
            return cached
    else:
        # Byte-range read — use in-memory range cache
        entry_key: _CacheEntryKey = (key, byte_range)
        ranges_for_key = self._state.range_cache.get(key)
        if ranges_for_key is not None and byte_range in ranges_for_key:
            self._state.hits += 1
            await self._update_access_order(entry_key)
            return ranges_for_key[byte_range]

    # Cache miss — fetch from source store
    self._state.misses += 1
    fetched = await super().get(key, prototype, byte_range)
    await self._cache_miss(key, byte_range, fetched)
    return fetched

async def _get_no_cache(
    self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
    """Bypass the caches: fetch from the source store, then refresh the cache.

    Used when a cached entry is stale. The read is counted as a miss, and
    ``_cache_miss`` either re-caches the fetched value or purges stale
    entries when the key no longer exists.
    """
    self._state.misses += 1
    fetched = await super().get(key, prototype, byte_range)
    await self._cache_miss(key, byte_range, fetched)
    return fetched

async def get(
self,
Expand All @@ -294,11 +351,10 @@ async def get(
Buffer | None
The retrieved data, or None if not found
"""
if not self._is_key_fresh(key):
logger.debug("get: key %s is not fresh, fetching from store", key)
entry_key: _CacheEntryKey = (key, byte_range) if byte_range is not None else key
if not self._is_key_fresh(entry_key):
return await self._get_no_cache(key, prototype, byte_range)
else:
logger.debug("get: key %s is fresh, trying cache", key)
return await self._get_try_cache(key, prototype, byte_range)

async def set(self, key: str, value: Buffer) -> None:
    """Write *value* to the source store and keep the caches coherent.

    Parameters
    ----------
    key : str
        The key to store
    value : Buffer
        The data to store

    Notes
    -----
    Cached byte-range entries for *key* are always invalidated, since the
    source data changed. The full value is then either cached
    (``cache_set_data=True``) or dropped from the cache.
    """
    await super().set(key, value)
    # Invalidate all cached byte-range entries (source data changed)
    async with self._state.lock:
        self._invalidate_range_entries(key)
    if self.cache_set_data:
        await self._cache.set(key, value)
        tracked = await self._track_entry(key, value)
        if not tracked:
            # Oversized value: roll the write back out of the cache so the
            # backing cache never holds entries the LRU accounting misses
            # (previously the ignored False left an untracked, unevictable
            # entry behind).
            await self._cache.delete(key)
    else:
        await self._cache.delete(key)
        async with self._state.lock:
            self._remove_from_tracking(key)
Expand All @@ -333,9 +389,10 @@ async def delete(self, key: str) -> None:
key : str
The key to delete
"""
logger.debug("delete: deleting key %s from store", key)
await super().delete(key)
logger.debug("delete: deleting key %s from cache", key)
# Invalidate all cached byte-range entries
async with self._state.lock:
self._invalidate_range_entries(key)
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
Expand Down Expand Up @@ -377,8 +434,8 @@ async def clear_cache(self) -> None:
self._state.key_insert_times.clear()
self._state.cache_order.clear()
self._state.key_sizes.clear()
self._state.range_cache.clear()
self._state.current_size = 0
logger.debug("clear_cache: cleared all cache data")

def __repr__(self) -> str:
"""Return string representation of the cache store."""
Expand Down
Loading