Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
__all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]


@dataclass
@dataclass(frozen=True, slots=True)
class RangeByteRequest:
"""Request a specific byte range"""

Expand All @@ -29,15 +29,15 @@ class RangeByteRequest:
"""The end of the byte range request (exclusive)."""


@dataclass
@dataclass(frozen=True, slots=True)
class OffsetByteRequest:
"""Request all bytes starting from a given byte offset"""

offset: int
"""The byte offset for the offset range request."""


@dataclass
@dataclass(frozen=True, slots=True)
class SuffixByteRequest:
"""Request up to the last `n` bytes"""

Expand Down
245 changes: 151 additions & 94 deletions src/zarr/experimental/cache_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
if TYPE_CHECKING:
from zarr.core.buffer.core import Buffer, BufferPrototype

# A cache entry identifier. Plain ``str`` for full-key entries that live in
# the Store-backed cache; ``(str, ByteRequest)`` for byte-range entries that
# live in the in-memory range cache.
_CacheEntryKey = str | tuple[str, ByteRequest]


@dataclass(slots=True)
class _CacheState:
cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
cache_order: OrderedDict[_CacheEntryKey, None] = field(default_factory=OrderedDict)
current_size: int = 0
key_sizes: dict[str, int] = field(default_factory=dict)
key_sizes: dict[_CacheEntryKey, int] = field(default_factory=dict)
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
hits: int = 0
misses: int = 0
evictions: int = 0
key_insert_times: dict[str, float] = field(default_factory=dict)
key_insert_times: dict[_CacheEntryKey, float] = field(default_factory=dict)
range_cache: dict[str, dict[ByteRequest, Buffer]] = field(default_factory=dict)


class CacheStore(WrapperStore[Store]):
Expand All @@ -36,6 +42,11 @@ class CacheStore(WrapperStore[Store]):
as the cache backend. This provides persistent caching capabilities with
time-based expiration, size-based eviction, and flexible cache storage options.

Full-key reads are cached in the Store-backed cache. Byte-range reads are
cached in a separate in-memory dictionary so that partial reads never
pollute the filesystem (or other persistent backend). Both caches share
the same ``max_size`` budget and LRU eviction policy.

Parameters
----------
store : Store
Expand Down Expand Up @@ -129,21 +140,21 @@ def with_read_only(self, read_only: bool = False) -> Self:
store._state = self._state
return store

def _is_key_fresh(self, key: str) -> bool:
"""Check if a cached key is still fresh based on max_age_seconds.
def _is_key_fresh(self, entry_key: _CacheEntryKey) -> bool:
"""Check if a cached entry is still fresh based on max_age_seconds.

Uses monotonic time for accurate elapsed time measurement.
"""
if self.max_age_seconds == "infinity":
return True
now = time.monotonic()
elapsed = now - self._state.key_insert_times.get(key, 0)
elapsed = now - self._state.key_insert_times.get(entry_key, 0)
return elapsed < self.max_age_seconds

async def _accommodate_value(self, value_size: int) -> None:
"""Ensure there is enough space in the cache for a new value.

Must be called while holding self._lock.
Must be called while holding self._state.lock.
"""
if self.max_size is None:
return
Expand All @@ -154,122 +165,168 @@ async def _accommodate_value(self, value_size: int) -> None:
lru_key = next(iter(self._state.cache_order))
await self._evict_key(lru_key)

async def _evict_key(self, key: str) -> None:
"""Evict a key from the cache.

Must be called while holding self._lock.
Updates size tracking atomically with deletion.
"""
try:
key_size = self._state.key_sizes.get(key, 0)

# Delete from cache store
await self._cache.delete(key)
async def _evict_key(self, entry_key: _CacheEntryKey) -> None:
"""Evict a cache entry.

# Update tracking after successful deletion
self._remove_from_tracking(key)
self._state.current_size = max(0, self._state.current_size - key_size)
self._state.evictions += 1
Must be called while holding self._state.lock.

logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
except Exception:
logger.exception("_evict_key: failed to evict key %s", key)
raise # Re-raise to signal eviction failure
For ``str`` keys the entry is deleted from the Store-backed cache.
For ``(str, ByteRequest)`` keys the entry is removed from the
in-memory range cache.
"""
key_size = self._state.key_sizes.get(entry_key, 0)

async def _cache_value(self, key: str, value: Buffer) -> None:
"""Cache a value with size tracking.
if isinstance(entry_key, str):
await self._cache.delete(entry_key)
else:
base_key, byte_range = entry_key
per_key = self._state.range_cache.get(base_key)
if per_key is not None:
per_key.pop(byte_range, None)
if not per_key:
del self._state.range_cache[base_key]

self._state.cache_order.pop(entry_key, None)
self._state.key_insert_times.pop(entry_key, None)
self._state.key_sizes.pop(entry_key, None)
self._state.current_size = max(0, self._state.current_size - key_size)
self._state.evictions += 1

async def _track_entry(self, entry_key: _CacheEntryKey, value: Buffer) -> bool:
"""Register *entry_key* in the shared size / LRU tracking.

Returns ``True`` if the entry was tracked, ``False`` if the value
exceeds ``max_size`` and was skipped. Callers should roll back any
data they already stored when this returns ``False``.

This method holds the lock for the entire operation to ensure atomicity.
"""
value_size = len(value)

# Check if value exceeds max size
if self.max_size is not None and value_size > self.max_size:
logger.warning(
"_cache_value: value size %d exceeds max_size %d, skipping cache",
value_size,
self.max_size,
)
return
return False

async with self._state.lock:
# If key already exists, subtract old size first
if key in self._state.key_sizes:
old_size = self._state.key_sizes[key]
if entry_key in self._state.key_sizes:
old_size = self._state.key_sizes[entry_key]
self._state.current_size -= old_size
logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)

# Make room for the new value (this calls _evict_key_locked internally)
# Make room for the new value
await self._accommodate_value(value_size)

# Update tracking atomically
self._state.cache_order[key] = None # OrderedDict to track access order
self._state.cache_order[entry_key] = None
self._state.current_size += value_size
self._state.key_sizes[key] = value_size
self._state.key_insert_times[key] = time.monotonic()
self._state.key_sizes[entry_key] = value_size
self._state.key_insert_times[entry_key] = time.monotonic()

logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
return True

async def _update_access_order(self, key: str) -> None:
async def _update_access_order(self, entry_key: _CacheEntryKey) -> None:
"""Update the access order for LRU tracking."""
if key in self._state.cache_order:
if entry_key in self._state.cache_order:
async with self._state.lock:
# Move to end (most recently used)
self._state.cache_order.move_to_end(key)
self._state.cache_order.move_to_end(entry_key)

def _remove_from_tracking(self, key: str) -> None:
"""Remove a key from all tracking structures.
def _remove_from_tracking(self, entry_key: _CacheEntryKey) -> None:
"""Remove an entry from all tracking structures.

Must be called while holding self._state.lock.
"""
self._state.cache_order.pop(key, None)
self._state.key_insert_times.pop(key, None)
self._state.key_sizes.pop(key, None)
self._state.cache_order.pop(entry_key, None)
self._state.key_insert_times.pop(entry_key, None)
self._state.key_sizes.pop(entry_key, None)

def _invalidate_range_entries(self, key: str) -> None:
    """Remove all byte-range entries for *key* from the range cache and tracking.

    Called when the full value for *key* changes or is deleted, so stale
    partial reads are never served afterwards.

    Must be called while holding self._state.lock.
    """
    # Pop the whole per-key dict of cached byte ranges in one step; ``None``
    # means no ranges were ever cached for this key.
    per_key = self._state.range_cache.pop(key, None)
    if per_key is not None:
        for byte_range in per_key:
            # Range entries are tracked under a (key, ByteRequest) tuple,
            # distinguishing them from plain-str full-key entries.
            entry_key: _CacheEntryKey = (key, byte_range)
            entry_size = self._state.key_sizes.pop(entry_key, 0)
            self._state.cache_order.pop(entry_key, None)
            self._state.key_insert_times.pop(entry_key, None)
            # Clamp at zero so accounting drift can never drive the
            # running total negative.
            self._state.current_size = max(0, self._state.current_size - entry_size)

# ------------------------------------------------------------------
# get helpers
# ------------------------------------------------------------------

async def _cache_miss(
    self, key: str, byte_range: ByteRequest | None, result: Buffer | None
) -> None:
    """Handle a cache miss by storing or cleaning up after a source-store fetch.

    Parameters
    ----------
    key : str
        The key that was fetched from the source store.
    byte_range : ByteRequest | None
        ``None`` for a full-key read (Store-backed cache); otherwise the
        byte range read (in-memory range cache).
    result : Buffer | None
        What the source store returned; ``None`` means the key does not
        exist there, so any cached copy must be invalidated.
    """
    if result is None:
        if byte_range is None:
            # Key gone from the source: drop the Store-backed copy and its
            # LRU/size bookkeeping.
            await self._cache.delete(key)
            async with self._state.lock:
                self._remove_from_tracking(key)
        else:
            # Drop only the one stale byte-range entry for this key.
            entry_key: _CacheEntryKey = (key, byte_range)
            async with self._state.lock:
                per_key = self._state.range_cache.get(key)
                if per_key is not None:
                    per_key.pop(byte_range, None)
                    # Remove the now-empty per-key dict to avoid leaking
                    # empty containers.
                    if not per_key:
                        del self._state.range_cache[key]
                self._remove_from_tracking(entry_key)
    else:
        if byte_range is None:
            # Full-key read: persist in the Store-backed cache, then
            # register it in the shared size/LRU tracking.
            await self._cache.set(key, result)
            await self._track_entry(key, result)
        else:
            # Byte-range read: store in the in-memory range cache first,
            # then attempt to track it against the shared budget.
            entry_key = (key, byte_range)
            self._state.range_cache.setdefault(key, {})[byte_range] = result
            tracked = await self._track_entry(entry_key, result)
            if not tracked:
                # Value too large for the cache — roll back the insertion
                per_key = self._state.range_cache.get(key)
                if per_key is not None:
                    per_key.pop(byte_range, None)
                    if not per_key:
                        del self._state.range_cache[key]

async def _get_try_cache(
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
"""Try to get data from cache first, falling back to source store."""
maybe_cached_result = await self._cache.get(key, prototype, byte_range)
if maybe_cached_result is not None:
logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
self._state.hits += 1
# Update access order for LRU
await self._update_access_order(key)
return maybe_cached_result
if byte_range is None:
# Full-key read — use Store-backed cache
maybe_cached = await self._cache.get(key, prototype)
if maybe_cached is not None:
self._state.hits += 1
await self._update_access_order(key)
return maybe_cached
else:
logger.debug(
"_get_try_cache: key %s not found in cache (MISS), fetching from store", key
)
self._state.misses += 1
maybe_fresh_result = await super().get(key, prototype, byte_range)
if maybe_fresh_result is None:
# Key doesn't exist in source store
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
else:
# Cache the newly fetched value
await self._cache.set(key, maybe_fresh_result)
await self._cache_value(key, maybe_fresh_result)
return maybe_fresh_result
# Byte-range read — use in-memory range cache
entry_key: _CacheEntryKey = (key, byte_range)
per_key = self._state.range_cache.get(key)
if per_key is not None:
cached_buf = per_key.get(byte_range)
if cached_buf is not None:
self._state.hits += 1
await self._update_access_order(entry_key)
return cached_buf

# Cache miss — fetch from source store
self._state.misses += 1
result = await super().get(key, prototype, byte_range)
await self._cache_miss(key, byte_range, result)
return result

async def _get_no_cache(
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
"""Get data directly from source store and update cache."""
self._state.misses += 1
maybe_fresh_result = await super().get(key, prototype, byte_range)
if maybe_fresh_result is None:
# Key doesn't exist in source, remove from cache and tracking
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
else:
logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
await self._cache.set(key, maybe_fresh_result)
await self._cache_value(key, maybe_fresh_result)
return maybe_fresh_result
result = await super().get(key, prototype, byte_range)
await self._cache_miss(key, byte_range, result)
return result

async def get(
self,
Expand All @@ -294,11 +351,10 @@ async def get(
Buffer | None
The retrieved data, or None if not found
"""
if not self._is_key_fresh(key):
logger.debug("get: key %s is not fresh, fetching from store", key)
entry_key: _CacheEntryKey = (key, byte_range) if byte_range is not None else key
if not self._is_key_fresh(entry_key):
return await self._get_no_cache(key, prototype, byte_range)
else:
logger.debug("get: key %s is fresh, trying cache", key)
return await self._get_try_cache(key, prototype, byte_range)

async def set(self, key: str, value: Buffer) -> None:
Expand All @@ -312,14 +368,14 @@ async def set(self, key: str, value: Buffer) -> None:
value : Buffer
The data to store
"""
logger.debug("set: setting key %s in store", key)
await super().set(key, value)
# Invalidate all cached byte-range entries (source data changed)
async with self._state.lock:
self._invalidate_range_entries(key)
if self.cache_set_data:
logger.debug("set: setting key %s in cache", key)
await self._cache.set(key, value)
await self._cache_value(key, value)
await self._track_entry(key, value)
else:
logger.debug("set: deleting key %s from cache", key)
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
Expand All @@ -333,9 +389,10 @@ async def delete(self, key: str) -> None:
key : str
The key to delete
"""
logger.debug("delete: deleting key %s from store", key)
await super().delete(key)
logger.debug("delete: deleting key %s from cache", key)
# Invalidate all cached byte-range entries
async with self._state.lock:
self._invalidate_range_entries(key)
await self._cache.delete(key)
async with self._state.lock:
self._remove_from_tracking(key)
Expand Down Expand Up @@ -377,8 +434,8 @@ async def clear_cache(self) -> None:
self._state.key_insert_times.clear()
self._state.cache_order.clear()
self._state.key_sizes.clear()
self._state.range_cache.clear()
self._state.current_size = 0
logger.debug("clear_cache: cleared all cache data")

def __repr__(self) -> str:
"""Return string representation of the cache store."""
Expand Down
Loading