Skip to content

Commit b8094d8

Browse files
authored
fix: implement separate cache for byte-range-requests (zarr-developers#3710)
* implement separate cache for byte-range-requests * changelog
1 parent d926e43 commit b8094d8

File tree

4 files changed

+307
-110
lines changed

4 files changed

+307
-110
lines changed

changes/3710.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add a dedicated in-memory cache for byte-range requests to the experimental `CacheStore`.

src/zarr/abc/store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
2020

2121

22-
@dataclass
22+
@dataclass(frozen=True, slots=True)
2323
class RangeByteRequest:
2424
"""Request a specific byte range"""
2525

@@ -29,15 +29,15 @@ class RangeByteRequest:
2929
"""The end of the byte range request (exclusive)."""
3030

3131

32-
@dataclass
32+
@dataclass(frozen=True, slots=True)
3333
class OffsetByteRequest:
3434
"""Request all bytes starting from a given byte offset"""
3535

3636
offset: int
3737
"""The byte offset for the offset range request."""
3838

3939

40-
@dataclass
40+
@dataclass(frozen=True, slots=True)
4141
class SuffixByteRequest:
4242
"""Request up to the last `n` bytes"""
4343

src/zarr/experimental/cache_store.py

Lines changed: 151 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,23 @@
1515
if TYPE_CHECKING:
1616
from zarr.core.buffer.core import Buffer, BufferPrototype
1717

18+
# A cache entry identifier. Plain ``str`` for full-key entries that live in
19+
# the Store-backed cache; ``(str, ByteRequest)`` for byte-range entries that
20+
# live in the in-memory range cache.
21+
_CacheEntryKey = str | tuple[str, ByteRequest]
22+
1823

1924
@dataclass(slots=True)
2025
class _CacheState:
21-
cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
26+
cache_order: OrderedDict[_CacheEntryKey, None] = field(default_factory=OrderedDict)
2227
current_size: int = 0
23-
key_sizes: dict[str, int] = field(default_factory=dict)
28+
key_sizes: dict[_CacheEntryKey, int] = field(default_factory=dict)
2429
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
2530
hits: int = 0
2631
misses: int = 0
2732
evictions: int = 0
28-
key_insert_times: dict[str, float] = field(default_factory=dict)
33+
key_insert_times: dict[_CacheEntryKey, float] = field(default_factory=dict)
34+
range_cache: dict[str, dict[ByteRequest, Buffer]] = field(default_factory=dict)
2935

3036

3137
class CacheStore(WrapperStore[Store]):
@@ -36,6 +42,11 @@ class CacheStore(WrapperStore[Store]):
3642
as the cache backend. This provides persistent caching capabilities with
3743
time-based expiration, size-based eviction, and flexible cache storage options.
3844
45+
Full-key reads are cached in the Store-backed cache. Byte-range reads are
46+
cached in a separate in-memory dictionary so that partial reads never
47+
pollute the filesystem (or other persistent backend). Both caches share
48+
the same ``max_size`` budget and LRU eviction policy.
49+
3950
Parameters
4051
----------
4152
store : Store
@@ -129,21 +140,21 @@ def with_read_only(self, read_only: bool = False) -> Self:
129140
store._state = self._state
130141
return store
131142

132-
def _is_key_fresh(self, key: str) -> bool:
133-
"""Check if a cached key is still fresh based on max_age_seconds.
143+
def _is_key_fresh(self, entry_key: _CacheEntryKey) -> bool:
144+
"""Check if a cached entry is still fresh based on max_age_seconds.
134145
135146
Uses monotonic time for accurate elapsed time measurement.
136147
"""
137148
if self.max_age_seconds == "infinity":
138149
return True
139150
now = time.monotonic()
140-
elapsed = now - self._state.key_insert_times.get(key, 0)
151+
elapsed = now - self._state.key_insert_times.get(entry_key, 0)
141152
return elapsed < self.max_age_seconds
142153

143154
async def _accommodate_value(self, value_size: int) -> None:
144155
"""Ensure there is enough space in the cache for a new value.
145156
146-
Must be called while holding self._lock.
157+
Must be called while holding self._state.lock.
147158
"""
148159
if self.max_size is None:
149160
return
@@ -154,122 +165,168 @@ async def _accommodate_value(self, value_size: int) -> None:
154165
lru_key = next(iter(self._state.cache_order))
155166
await self._evict_key(lru_key)
156167

157-
async def _evict_key(self, key: str) -> None:
158-
"""Evict a key from the cache.
159-
160-
Must be called while holding self._lock.
161-
Updates size tracking atomically with deletion.
162-
"""
163-
try:
164-
key_size = self._state.key_sizes.get(key, 0)
165-
166-
# Delete from cache store
167-
await self._cache.delete(key)
168+
async def _evict_key(self, entry_key: _CacheEntryKey) -> None:
169+
"""Evict a cache entry.
168170
169-
# Update tracking after successful deletion
170-
self._remove_from_tracking(key)
171-
self._state.current_size = max(0, self._state.current_size - key_size)
172-
self._state.evictions += 1
171+
Must be called while holding self._state.lock.
173172
174-
logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
175-
except Exception:
176-
logger.exception("_evict_key: failed to evict key %s", key)
177-
raise # Re-raise to signal eviction failure
173+
For ``str`` keys the entry is deleted from the Store-backed cache.
174+
For ``(str, ByteRequest)`` keys the entry is removed from the
175+
in-memory range cache.
176+
"""
177+
key_size = self._state.key_sizes.get(entry_key, 0)
178178

179-
async def _cache_value(self, key: str, value: Buffer) -> None:
180-
"""Cache a value with size tracking.
179+
if isinstance(entry_key, str):
180+
await self._cache.delete(entry_key)
181+
else:
182+
base_key, byte_range = entry_key
183+
per_key = self._state.range_cache.get(base_key)
184+
if per_key is not None:
185+
per_key.pop(byte_range, None)
186+
if not per_key:
187+
del self._state.range_cache[base_key]
188+
189+
self._state.cache_order.pop(entry_key, None)
190+
self._state.key_insert_times.pop(entry_key, None)
191+
self._state.key_sizes.pop(entry_key, None)
192+
self._state.current_size = max(0, self._state.current_size - key_size)
193+
self._state.evictions += 1
194+
195+
async def _track_entry(self, entry_key: _CacheEntryKey, value: Buffer) -> bool:
196+
"""Register *entry_key* in the shared size / LRU tracking.
197+
198+
Returns ``True`` if the entry was tracked, ``False`` if the value
199+
exceeds ``max_size`` and was skipped. Callers should roll back any
200+
data they already stored when this returns ``False``.
181201
182202
This method holds the lock for the entire operation to ensure atomicity.
183203
"""
184204
value_size = len(value)
185205

186206
# Check if value exceeds max size
187207
if self.max_size is not None and value_size > self.max_size:
188-
logger.warning(
189-
"_cache_value: value size %d exceeds max_size %d, skipping cache",
190-
value_size,
191-
self.max_size,
192-
)
193-
return
208+
return False
194209

195210
async with self._state.lock:
196211
# If key already exists, subtract old size first
197-
if key in self._state.key_sizes:
198-
old_size = self._state.key_sizes[key]
212+
if entry_key in self._state.key_sizes:
213+
old_size = self._state.key_sizes[entry_key]
199214
self._state.current_size -= old_size
200-
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)
201215

202-
# Make room for the new value (this calls _evict_key_locked internally)
216+
# Make room for the new value
203217
await self._accommodate_value(value_size)
204218

205219
# Update tracking atomically
206-
self._state.cache_order[key] = None # OrderedDict to track access order
220+
self._state.cache_order[entry_key] = None
207221
self._state.current_size += value_size
208-
self._state.key_sizes[key] = value_size
209-
self._state.key_insert_times[key] = time.monotonic()
222+
self._state.key_sizes[entry_key] = value_size
223+
self._state.key_insert_times[entry_key] = time.monotonic()
210224

211-
logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
225+
return True
212226

213-
async def _update_access_order(self, key: str) -> None:
227+
async def _update_access_order(self, entry_key: _CacheEntryKey) -> None:
214228
"""Update the access order for LRU tracking."""
215-
if key in self._state.cache_order:
229+
if entry_key in self._state.cache_order:
216230
async with self._state.lock:
217-
# Move to end (most recently used)
218-
self._state.cache_order.move_to_end(key)
231+
self._state.cache_order.move_to_end(entry_key)
219232

220-
def _remove_from_tracking(self, key: str) -> None:
221-
"""Remove a key from all tracking structures.
233+
def _remove_from_tracking(self, entry_key: _CacheEntryKey) -> None:
234+
"""Remove an entry from all tracking structures.
222235
223236
Must be called while holding self._state.lock.
224237
"""
225-
self._state.cache_order.pop(key, None)
226-
self._state.key_insert_times.pop(key, None)
227-
self._state.key_sizes.pop(key, None)
238+
self._state.cache_order.pop(entry_key, None)
239+
self._state.key_insert_times.pop(entry_key, None)
240+
self._state.key_sizes.pop(entry_key, None)
241+
242+
def _invalidate_range_entries(self, key: str) -> None:
243+
"""Remove all byte-range entries for *key* from the range cache and tracking.
244+
245+
Must be called while holding self._state.lock.
246+
"""
247+
per_key = self._state.range_cache.pop(key, None)
248+
if per_key is not None:
249+
for byte_range in per_key:
250+
entry_key: _CacheEntryKey = (key, byte_range)
251+
entry_size = self._state.key_sizes.pop(entry_key, 0)
252+
self._state.cache_order.pop(entry_key, None)
253+
self._state.key_insert_times.pop(entry_key, None)
254+
self._state.current_size = max(0, self._state.current_size - entry_size)
255+
256+
# ------------------------------------------------------------------
257+
# get helpers
258+
# ------------------------------------------------------------------
259+
260+
async def _cache_miss(
261+
self, key: str, byte_range: ByteRequest | None, result: Buffer | None
262+
) -> None:
263+
"""Handle a cache miss by storing or cleaning up after a source-store fetch."""
264+
if result is None:
265+
if byte_range is None:
266+
await self._cache.delete(key)
267+
async with self._state.lock:
268+
self._remove_from_tracking(key)
269+
else:
270+
entry_key: _CacheEntryKey = (key, byte_range)
271+
async with self._state.lock:
272+
per_key = self._state.range_cache.get(key)
273+
if per_key is not None:
274+
per_key.pop(byte_range, None)
275+
if not per_key:
276+
del self._state.range_cache[key]
277+
self._remove_from_tracking(entry_key)
278+
else:
279+
if byte_range is None:
280+
await self._cache.set(key, result)
281+
await self._track_entry(key, result)
282+
else:
283+
entry_key = (key, byte_range)
284+
self._state.range_cache.setdefault(key, {})[byte_range] = result
285+
tracked = await self._track_entry(entry_key, result)
286+
if not tracked:
287+
# Value too large for the cache — roll back the insertion
288+
per_key = self._state.range_cache.get(key)
289+
if per_key is not None:
290+
per_key.pop(byte_range, None)
291+
if not per_key:
292+
del self._state.range_cache[key]
228293

229294
async def _get_try_cache(
230295
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
231296
) -> Buffer | None:
232297
"""Try to get data from cache first, falling back to source store."""
233-
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
234-
if maybe_cached_result is not None:
235-
logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
236-
self._state.hits += 1
237-
# Update access order for LRU
238-
await self._update_access_order(key)
239-
return maybe_cached_result
298+
if byte_range is None:
299+
# Full-key read — use Store-backed cache
300+
maybe_cached = await self._cache.get(key, prototype)
301+
if maybe_cached is not None:
302+
self._state.hits += 1
303+
await self._update_access_order(key)
304+
return maybe_cached
240305
else:
241-
logger.debug(
242-
"_get_try_cache: key %s not found in cache (MISS), fetching from store", key
243-
)
244-
self._state.misses += 1
245-
maybe_fresh_result = await super().get(key, prototype, byte_range)
246-
if maybe_fresh_result is None:
247-
# Key doesn't exist in source store
248-
await self._cache.delete(key)
249-
async with self._state.lock:
250-
self._remove_from_tracking(key)
251-
else:
252-
# Cache the newly fetched value
253-
await self._cache.set(key, maybe_fresh_result)
254-
await self._cache_value(key, maybe_fresh_result)
255-
return maybe_fresh_result
306+
# Byte-range read — use in-memory range cache
307+
entry_key: _CacheEntryKey = (key, byte_range)
308+
per_key = self._state.range_cache.get(key)
309+
if per_key is not None:
310+
cached_buf = per_key.get(byte_range)
311+
if cached_buf is not None:
312+
self._state.hits += 1
313+
await self._update_access_order(entry_key)
314+
return cached_buf
315+
316+
# Cache miss — fetch from source store
317+
self._state.misses += 1
318+
result = await super().get(key, prototype, byte_range)
319+
await self._cache_miss(key, byte_range, result)
320+
return result
256321

257322
async def _get_no_cache(
258323
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
259324
) -> Buffer | None:
260325
"""Get data directly from source store and update cache."""
261326
self._state.misses += 1
262-
maybe_fresh_result = await super().get(key, prototype, byte_range)
263-
if maybe_fresh_result is None:
264-
# Key doesn't exist in source, remove from cache and tracking
265-
await self._cache.delete(key)
266-
async with self._state.lock:
267-
self._remove_from_tracking(key)
268-
else:
269-
logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
270-
await self._cache.set(key, maybe_fresh_result)
271-
await self._cache_value(key, maybe_fresh_result)
272-
return maybe_fresh_result
327+
result = await super().get(key, prototype, byte_range)
328+
await self._cache_miss(key, byte_range, result)
329+
return result
273330

274331
async def get(
275332
self,
@@ -294,11 +351,10 @@ async def get(
294351
Buffer | None
295352
The retrieved data, or None if not found
296353
"""
297-
if not self._is_key_fresh(key):
298-
logger.debug("get: key %s is not fresh, fetching from store", key)
354+
entry_key: _CacheEntryKey = (key, byte_range) if byte_range is not None else key
355+
if not self._is_key_fresh(entry_key):
299356
return await self._get_no_cache(key, prototype, byte_range)
300357
else:
301-
logger.debug("get: key %s is fresh, trying cache", key)
302358
return await self._get_try_cache(key, prototype, byte_range)
303359

304360
async def set(self, key: str, value: Buffer) -> None:
@@ -312,14 +368,14 @@ async def set(self, key: str, value: Buffer) -> None:
312368
value : Buffer
313369
The data to store
314370
"""
315-
logger.debug("set: setting key %s in store", key)
316371
await super().set(key, value)
372+
# Invalidate all cached byte-range entries (source data changed)
373+
async with self._state.lock:
374+
self._invalidate_range_entries(key)
317375
if self.cache_set_data:
318-
logger.debug("set: setting key %s in cache", key)
319376
await self._cache.set(key, value)
320-
await self._cache_value(key, value)
377+
await self._track_entry(key, value)
321378
else:
322-
logger.debug("set: deleting key %s from cache", key)
323379
await self._cache.delete(key)
324380
async with self._state.lock:
325381
self._remove_from_tracking(key)
@@ -333,9 +389,10 @@ async def delete(self, key: str) -> None:
333389
key : str
334390
The key to delete
335391
"""
336-
logger.debug("delete: deleting key %s from store", key)
337392
await super().delete(key)
338-
logger.debug("delete: deleting key %s from cache", key)
393+
# Invalidate all cached byte-range entries
394+
async with self._state.lock:
395+
self._invalidate_range_entries(key)
339396
await self._cache.delete(key)
340397
async with self._state.lock:
341398
self._remove_from_tracking(key)
@@ -377,8 +434,8 @@ async def clear_cache(self) -> None:
377434
self._state.key_insert_times.clear()
378435
self._state.cache_order.clear()
379436
self._state.key_sizes.clear()
437+
self._state.range_cache.clear()
380438
self._state.current_size = 0
381-
logger.debug("clear_cache: cleared all cache data")
382439

383440
def __repr__(self) -> str:
384441
"""Return string representation of the cache store."""

0 commit comments

Comments
 (0)