diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py
index 87df89a683..f677c197dc 100644
--- a/src/zarr/abc/store.py
+++ b/src/zarr/abc/store.py
@@ -19,7 +19,7 @@
 __all__ = ["ByteGetter", "ByteSetter", "Store", "set_or_delete"]
 
 
-@dataclass
+@dataclass(frozen=True, slots=True)
 class RangeByteRequest:
     """Request a specific byte range"""
 
@@ -29,7 +29,7 @@ class RangeByteRequest:
     """The end of the byte range request (exclusive)."""
 
 
-@dataclass
+@dataclass(frozen=True, slots=True)
 class OffsetByteRequest:
     """Request all bytes starting from a given byte offset"""
 
@@ -37,7 +37,7 @@ class OffsetByteRequest:
     """The byte offset for the offset range request."""
 
 
-@dataclass
+@dataclass(frozen=True, slots=True)
 class SuffixByteRequest:
     """Request up to the last `n` bytes"""
 
diff --git a/src/zarr/experimental/cache_store.py b/src/zarr/experimental/cache_store.py
index 87adc90c83..1535b42f67 100644
--- a/src/zarr/experimental/cache_store.py
+++ b/src/zarr/experimental/cache_store.py
@@ -15,17 +15,23 @@ if TYPE_CHECKING:
     from zarr.core.buffer.core import Buffer, BufferPrototype
 
+# A cache entry identifier. Plain ``str`` for full-key entries that live in
+# the Store-backed cache; ``(str, ByteRequest)`` for byte-range entries that
+# live in the in-memory range cache.
+_CacheEntryKey = str | tuple[str, ByteRequest]
+
 
 @dataclass(slots=True)
 class _CacheState:
-    cache_order: OrderedDict[str, None] = field(default_factory=OrderedDict)
+    cache_order: OrderedDict[_CacheEntryKey, None] = field(default_factory=OrderedDict)
     current_size: int = 0
-    key_sizes: dict[str, int] = field(default_factory=dict)
+    key_sizes: dict[_CacheEntryKey, int] = field(default_factory=dict)
     lock: asyncio.Lock = field(default_factory=asyncio.Lock)
     hits: int = 0
     misses: int = 0
     evictions: int = 0
-    key_insert_times: dict[str, float] = field(default_factory=dict)
+    key_insert_times: dict[_CacheEntryKey, float] = field(default_factory=dict)
+    range_cache: dict[str, dict[ByteRequest, Buffer]] = field(default_factory=dict)
 
 
 class CacheStore(WrapperStore[Store]):
@@ -36,6 +42,11 @@ class CacheStore(WrapperStore[Store]):
     as the cache backend. This provides persistent caching capabilities with
     time-based expiration, size-based eviction, and flexible cache storage options.
 
+    Full-key reads are cached in the Store-backed cache. Byte-range reads are
+    cached in a separate in-memory dictionary so that partial reads never
+    pollute the filesystem (or other persistent backend). Both caches share
+    the same ``max_size`` budget and LRU eviction policy.
+
     Parameters
     ----------
     store : Store
@@ -129,21 +140,21 @@ def with_read_only(self, read_only: bool = False) -> Self:
         store._state = self._state
         return store
 
-    def _is_key_fresh(self, key: str) -> bool:
-        """Check if a cached key is still fresh based on max_age_seconds.
+    def _is_key_fresh(self, entry_key: _CacheEntryKey) -> bool:
+        """Check if a cached entry is still fresh based on max_age_seconds.
 
         Uses monotonic time for accurate elapsed time measurement.
         """
         if self.max_age_seconds == "infinity":
             return True
         now = time.monotonic()
-        elapsed = now - self._state.key_insert_times.get(key, 0)
+        elapsed = now - self._state.key_insert_times.get(entry_key, 0)
         return elapsed < self.max_age_seconds
 
     async def _accommodate_value(self, value_size: int) -> None:
         """Ensure there is enough space in the cache for a new value.
 
-        Must be called while holding self._lock.
+        Must be called while holding self._state.lock.
         """
         if self.max_size is None:
             return
@@ -154,30 +165,39 @@ async def _accommodate_value(self, value_size: int) -> None:
             lru_key = next(iter(self._state.cache_order))
             await self._evict_key(lru_key)
 
-    async def _evict_key(self, key: str) -> None:
-        """Evict a key from the cache.
-
-        Must be called while holding self._lock.
-        Updates size tracking atomically with deletion.
-        """
-        try:
-            key_size = self._state.key_sizes.get(key, 0)
-
-            # Delete from cache store
-            await self._cache.delete(key)
+    async def _evict_key(self, entry_key: _CacheEntryKey) -> None:
+        """Evict a cache entry.
 
-            # Update tracking after successful deletion
-            self._remove_from_tracking(key)
-            self._state.current_size = max(0, self._state.current_size - key_size)
-            self._state.evictions += 1
+        Must be called while holding self._state.lock.
 
-            logger.debug("_evict_key: evicted key %s, freed %d bytes", key, key_size)
-        except Exception:
-            logger.exception("_evict_key: failed to evict key %s", key)
-            raise  # Re-raise to signal eviction failure
+        For ``str`` keys the entry is deleted from the Store-backed cache.
+        For ``(str, ByteRequest)`` keys the entry is removed from the
+        in-memory range cache.
+        """
+        key_size = self._state.key_sizes.get(entry_key, 0)
 
-    async def _cache_value(self, key: str, value: Buffer) -> None:
-        """Cache a value with size tracking.
+        if isinstance(entry_key, str):
+            await self._cache.delete(entry_key)
+        else:
+            base_key, byte_range = entry_key
+            per_key = self._state.range_cache.get(base_key)
+            if per_key is not None:
+                per_key.pop(byte_range, None)
+                if not per_key:
+                    del self._state.range_cache[base_key]
+
+        self._state.cache_order.pop(entry_key, None)
+        self._state.key_insert_times.pop(entry_key, None)
+        self._state.key_sizes.pop(entry_key, None)
+        self._state.current_size = max(0, self._state.current_size - key_size)
+        self._state.evictions += 1
+
+    async def _track_entry(self, entry_key: _CacheEntryKey, value: Buffer) -> bool:
+        """Register *entry_key* in the shared size / LRU tracking.
+
+        Returns ``True`` if the entry was tracked, ``False`` if the value
+        exceeds ``max_size`` and was skipped. Callers should roll back any
+        data they already stored when this returns ``False``.
 
         This method holds the lock for the entire operation to ensure atomicity.
         """
@@ -185,91 +205,128 @@ async def _cache_value(self, key: str, value: Buffer) -> None:
 
         # Check if value exceeds max size
         if self.max_size is not None and value_size > self.max_size:
-            logger.warning(
-                "_cache_value: value size %d exceeds max_size %d, skipping cache",
-                value_size,
-                self.max_size,
-            )
-            return
+            return False
 
         async with self._state.lock:
             # If key already exists, subtract old size first
-            if key in self._state.key_sizes:
-                old_size = self._state.key_sizes[key]
+            if entry_key in self._state.key_sizes:
+                old_size = self._state.key_sizes[entry_key]
                 self._state.current_size -= old_size
-                logger.debug("_cache_value: updating existing key %s, old size %d", key, old_size)
 
-            # Make room for the new value (this calls _evict_key_locked internally)
+            # Make room for the new value
            await self._accommodate_value(value_size)
 
             # Update tracking atomically
-            self._state.cache_order[key] = None  # OrderedDict to track access order
+            self._state.cache_order[entry_key] = None
             self._state.current_size += value_size
-            self._state.key_sizes[key] = value_size
-            self._state.key_insert_times[key] = time.monotonic()
+            self._state.key_sizes[entry_key] = value_size
+            self._state.key_insert_times[entry_key] = time.monotonic()
 
-            logger.debug("_cache_value: cached key %s with size %d bytes", key, value_size)
+        return True
 
-    async def _update_access_order(self, key: str) -> None:
+    async def _update_access_order(self, entry_key: _CacheEntryKey) -> None:
         """Update the access order for LRU tracking."""
-        if key in self._state.cache_order:
+        if entry_key in self._state.cache_order:
             async with self._state.lock:
-                # Move to end (most recently used)
-                self._state.cache_order.move_to_end(key)
+                self._state.cache_order.move_to_end(entry_key)
 
-    def _remove_from_tracking(self, key: str) -> None:
-        """Remove a key from all tracking structures.
+    def _remove_from_tracking(self, entry_key: _CacheEntryKey) -> None:
+        """Remove an entry from all tracking structures.
 
         Must be called while holding self._state.lock.
         """
-        self._state.cache_order.pop(key, None)
-        self._state.key_insert_times.pop(key, None)
-        self._state.key_sizes.pop(key, None)
+        self._state.cache_order.pop(entry_key, None)
+        self._state.key_insert_times.pop(entry_key, None)
+        self._state.key_sizes.pop(entry_key, None)
+
+    def _invalidate_range_entries(self, key: str) -> None:
+        """Remove all byte-range entries for *key* from the range cache and tracking.
+
+        Must be called while holding self._state.lock.
+        """
+        per_key = self._state.range_cache.pop(key, None)
+        if per_key is not None:
+            for byte_range in per_key:
+                entry_key: _CacheEntryKey = (key, byte_range)
+                entry_size = self._state.key_sizes.pop(entry_key, 0)
+                self._state.cache_order.pop(entry_key, None)
+                self._state.key_insert_times.pop(entry_key, None)
+                self._state.current_size = max(0, self._state.current_size - entry_size)
+
+    # ------------------------------------------------------------------
+    # get helpers
+    # ------------------------------------------------------------------
+
+    async def _cache_miss(
+        self, key: str, byte_range: ByteRequest | None, result: Buffer | None
+    ) -> None:
+        """Handle a cache miss by storing or cleaning up after a source-store fetch."""
+        if result is None:
+            if byte_range is None:
+                await self._cache.delete(key)
+                async with self._state.lock:
+                    self._remove_from_tracking(key)
+            else:
+                entry_key: _CacheEntryKey = (key, byte_range)
+                async with self._state.lock:
+                    per_key = self._state.range_cache.get(key)
+                    if per_key is not None:
+                        per_key.pop(byte_range, None)
+                        if not per_key:
+                            del self._state.range_cache[key]
+                    self._remove_from_tracking(entry_key)
+        else:
+            if byte_range is None:
+                await self._cache.set(key, result)
+                await self._track_entry(key, result)
+            else:
+                entry_key = (key, byte_range)
+                self._state.range_cache.setdefault(key, {})[byte_range] = result
+                tracked = await self._track_entry(entry_key, result)
+                if not tracked:
+                    # Value too large for the cache — roll back the insertion
+                    per_key = self._state.range_cache.get(key)
+                    if per_key is not None:
+                        per_key.pop(byte_range, None)
+                        if not per_key:
+                            del self._state.range_cache[key]
 
     async def _get_try_cache(
         self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
     ) -> Buffer | None:
         """Try to get data from cache first, falling back to source store."""
-        maybe_cached_result = await self._cache.get(key, prototype, byte_range)
-        if maybe_cached_result is not None:
-            logger.debug("_get_try_cache: key %s found in cache (HIT)", key)
-            self._state.hits += 1
-            # Update access order for LRU
-            await self._update_access_order(key)
-            return maybe_cached_result
+        if byte_range is None:
+            # Full-key read — use Store-backed cache
+            maybe_cached = await self._cache.get(key, prototype)
+            if maybe_cached is not None:
+                self._state.hits += 1
+                await self._update_access_order(key)
+                return maybe_cached
         else:
-            logger.debug(
-                "_get_try_cache: key %s not found in cache (MISS), fetching from store", key
-            )
-            self._state.misses += 1
-            maybe_fresh_result = await super().get(key, prototype, byte_range)
-            if maybe_fresh_result is None:
-                # Key doesn't exist in source store
-                await self._cache.delete(key)
-                async with self._state.lock:
-                    self._remove_from_tracking(key)
-            else:
-                # Cache the newly fetched value
-                await self._cache.set(key, maybe_fresh_result)
-                await self._cache_value(key, maybe_fresh_result)
-            return maybe_fresh_result
+            # Byte-range read — use in-memory range cache
+            entry_key: _CacheEntryKey = (key, byte_range)
+            per_key = self._state.range_cache.get(key)
+            if per_key is not None:
+                cached_buf = per_key.get(byte_range)
+                if cached_buf is not None:
+                    self._state.hits += 1
+                    await self._update_access_order(entry_key)
+                    return cached_buf
+
+        # Cache miss — fetch from source store
+        self._state.misses += 1
+        result = await super().get(key, prototype, byte_range)
+        await self._cache_miss(key, byte_range, result)
+        return result
 
     async def _get_no_cache(
         self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
     ) -> Buffer | None:
         """Get data directly from source store and update cache."""
         self._state.misses += 1
-        maybe_fresh_result = await super().get(key, prototype, byte_range)
-        if maybe_fresh_result is None:
-            # Key doesn't exist in source, remove from cache and tracking
-            await self._cache.delete(key)
-            async with self._state.lock:
-                self._remove_from_tracking(key)
-        else:
-            logger.debug("_get_no_cache: key %s found in store, setting in cache", key)
-            await self._cache.set(key, maybe_fresh_result)
-            await self._cache_value(key, maybe_fresh_result)
-        return maybe_fresh_result
+        result = await super().get(key, prototype, byte_range)
+        await self._cache_miss(key, byte_range, result)
+        return result
 
     async def get(
         self,
@@ -294,11 +351,10 @@ async def get(
         Buffer | None
             The retrieved data, or None if not found
         """
-        if not self._is_key_fresh(key):
-            logger.debug("get: key %s is not fresh, fetching from store", key)
+        entry_key: _CacheEntryKey = (key, byte_range) if byte_range is not None else key
+        if not self._is_key_fresh(entry_key):
             return await self._get_no_cache(key, prototype, byte_range)
         else:
-            logger.debug("get: key %s is fresh, trying cache", key)
             return await self._get_try_cache(key, prototype, byte_range)
 
     async def set(self, key: str, value: Buffer) -> None:
@@ -312,14 +368,14 @@ async def set(self, key: str, value: Buffer) -> None:
         value : Buffer
             The data to store
         """
-        logger.debug("set: setting key %s in store", key)
         await super().set(key, value)
+        # Invalidate all cached byte-range entries (source data changed)
+        async with self._state.lock:
+            self._invalidate_range_entries(key)
         if self.cache_set_data:
-            logger.debug("set: setting key %s in cache", key)
             await self._cache.set(key, value)
-            await self._cache_value(key, value)
+            await self._track_entry(key, value)
         else:
-            logger.debug("set: deleting key %s from cache", key)
             await self._cache.delete(key)
             async with self._state.lock:
                 self._remove_from_tracking(key)
@@ -333,9 +389,10 @@ async def delete(self, key: str) -> None:
         key : str
             The key to delete
         """
-        logger.debug("delete: deleting key %s from store", key)
         await super().delete(key)
-        logger.debug("delete: deleting key %s from cache", key)
+        # Invalidate all cached byte-range entries
+        async with self._state.lock:
+            self._invalidate_range_entries(key)
         await self._cache.delete(key)
         async with self._state.lock:
             self._remove_from_tracking(key)
@@ -377,8 +434,8 @@ async def clear_cache(self) -> None:
             self._state.key_insert_times.clear()
             self._state.cache_order.clear()
             self._state.key_sizes.clear()
+            self._state.range_cache.clear()
             self._state.current_size = 0
-            logger.debug("clear_cache: cleared all cache data")
 
     def __repr__(self) -> str:
         """Return string representation of the cache store."""
diff --git a/tests/test_experimental/test_cache_store.py b/tests/test_experimental/test_cache_store.py
index 50d3d9506b..fc17ccd5e1 100644
--- a/tests/test_experimental/test_cache_store.py
+++ b/tests/test_experimental/test_cache_store.py
@@ -7,7 +7,7 @@
 
 import pytest
 
-from zarr.abc.store import Store
+from zarr.abc.store import RangeByteRequest, Store, SuffixByteRequest
 from zarr.core.buffer.core import default_buffer_prototype
 from zarr.core.buffer.cpu import Buffer as CPUBuffer
 from zarr.experimental.cache_store import CacheStore
@@ -581,7 +581,7 @@ async def test_get_no_cache_delete_tracking(self) -> None:
         # First, add key to cache tracking but not to source
         test_data = CPUBuffer.from_bytes(b"test data")
         await cache_store.set("phantom_key", test_data)
-        await cached_store._cache_value("phantom_key", test_data)
+        await cached_store._track_entry("phantom_key", test_data)
 
         # Verify it's in tracking
         assert "phantom_key" in cached_store._state.cache_order
@@ -778,17 +778,20 @@ async def test_all_tracked_keys_exist_in_cache_store(self) -> None:
             data = CPUBuffer.from_bytes(b"x" * 50)
             await cached_store.set(f"key_{i}", data)
 
-        # Every key in tracking should exist in cache_store
-        for key in cached_store._state.cache_order:
-            assert await cache_store.exists(key), (
-                f"Key '{key}' is tracked but doesn't exist in cache_store"
-            )
-
-        # Every key in _key_sizes should exist in cache_store
-        for key in cached_store._state.key_sizes:
-            assert await cache_store.exists(key), (
-                f"Key '{key}' has size tracked but doesn't exist in cache_store"
-            )
+        # Every str key in tracking should exist in cache_store
+        # (tuple keys are byte-range entries stored in-memory, not in the Store)
+        for entry_key in cached_store._state.cache_order:
+            if isinstance(entry_key, str):
+                assert await cache_store.exists(entry_key), (
+                    f"Key '{entry_key}' is tracked but doesn't exist in cache_store"
+                )
+
+        # Every str key in _key_sizes should exist in cache_store
+        for entry_key in cached_store._state.key_sizes:
+            if isinstance(entry_key, str):
+                assert await cache_store.exists(entry_key), (
+                    f"Key '{entry_key}' has size tracked but doesn't exist in cache_store"
+                )
 
 
 # Additional coverage tests for 100% coverage
@@ -908,3 +911,139 @@ async def test_cache_stats_zero_division_protection(self) -> None:
         stats = cached_store.cache_stats()
         assert stats["hit_rate"] == 0.0
         assert stats["total_requests"] == 0
+
+    async def test_byte_range_does_not_corrupt_cache(self) -> None:
+        """Test that fetching a byte range does not store partial data under the full key.
+
+        Reproduces https://github.com/zarr-developers/zarr-python/issues/3690:
+        when a byte-range read populates the cache, subsequent reads of different
+        ranges (or the full key) return wrong data.
+        """
+        source_store = MemoryStore()
+        cache_store = MemoryStore()
+        cached_store = CacheStore(store=source_store, cache_store=cache_store)
+
+        full_data = b"bar baz"
+        await source_store.set("foo", CPUBuffer.from_bytes(full_data))
+
+        proto = default_buffer_prototype()
+
+        # First read: byte range [0, 3) -> b"bar"
+        bar = await cached_store.get("foo", proto, byte_range=RangeByteRequest(0, 3))
+        assert bar is not None
+        assert bar.to_bytes() == b"bar"
+
+        # Second read: different byte range [4, 7) -> b"baz"
+        baz = await cached_store.get("foo", proto, byte_range=RangeByteRequest(4, 7))
+        assert baz is not None
+        assert baz.to_bytes() == b"baz"
+
+        # Third read: full key -> full data
+        full = await cached_store.get("foo", proto)
+        assert full is not None
+        assert full.to_bytes() == full_data
+
+    async def test_full_read_then_byte_range(self) -> None:
+        """Test that a cached full read correctly serves subsequent byte-range requests."""
+        source_store = MemoryStore()
+        cache_store = MemoryStore()
+        cached_store = CacheStore(store=source_store, cache_store=cache_store)
+
+        full_data = b"hello world"
+        await source_store.set("key", CPUBuffer.from_bytes(full_data))
+
+        proto = default_buffer_prototype()
+
+        # Full read populates cache
+        full = await cached_store.get("key", proto)
+        assert full is not None
+        assert full.to_bytes() == full_data
+
+        # Byte-range reads should return the correct slices
+        part = await cached_store.get("key", proto, byte_range=RangeByteRequest(0, 5))
+        assert part is not None
+        assert part.to_bytes() == b"hello"
+
+        part2 = await cached_store.get("key", proto, byte_range=RangeByteRequest(6, 11))
+        assert part2 is not None
+        assert part2.to_bytes() == b"world"
+
+        suffix = await cached_store.get("key", proto, byte_range=SuffixByteRequest(5))
+        assert suffix is not None
+        assert suffix.to_bytes() == b"world"
+
+    async def test_byte_range_set_then_read(self) -> None:
+        """Test that data written via set() can be read back with byte ranges."""
+        source_store = MemoryStore()
+        cache_store = MemoryStore()
+        cached_store = CacheStore(store=source_store, cache_store=cache_store)
+
+        full_data = b"abcdefghij"
+        await cached_store.set("key", CPUBuffer.from_bytes(full_data))
+
+        proto = default_buffer_prototype()
+
+        # Byte-range reads from the cached data
+        mid = await cached_store.get("key", proto, byte_range=RangeByteRequest(3, 7))
+        assert mid is not None
+        assert mid.to_bytes() == b"defg"
+
+        # Full read should still work
+        full = await cached_store.get("key", proto)
+        assert full is not None
+        assert full.to_bytes() == full_data
+
+    async def test_set_invalidates_cached_byte_ranges(self) -> None:
+        """Test that set() invalidates previously cached byte-range entries."""
+        source_store = MemoryStore()
+        cache_store = MemoryStore()
+        cached_store = CacheStore(store=source_store, cache_store=cache_store)
+
+        proto = default_buffer_prototype()
+
+        # Populate source and cache some byte ranges
+        await source_store.set("key", CPUBuffer.from_bytes(b"old data!!"))
+        r1 = await cached_store.get("key", proto, byte_range=RangeByteRequest(0, 3))
+        assert r1 is not None
+        assert r1.to_bytes() == b"old"
+
+        # Byte-range entry should be in range_cache
+        assert ("key", RangeByteRequest(0, 3)) in cached_store._state.cache_order
+
+        # Overwrite via set() — range entries must be invalidated
+        await cached_store.set("key", CPUBuffer.from_bytes(b"NEW DATA!!"))
+
+        # The old range entry should be gone from tracking and range_cache
+        assert ("key", RangeByteRequest(0, 3)) not in cached_store._state.cache_order
+        assert "key" not in cached_store._state.range_cache
+
+        # A fresh byte-range read should return the new data
+        r2 = await cached_store.get("key", proto, byte_range=RangeByteRequest(0, 3))
+        assert r2 is not None
+        assert r2.to_bytes() == b"NEW"
+
+    async def test_delete_invalidates_cached_byte_ranges(self) -> None:
+        """Test that delete() removes previously cached byte-range entries."""
+        source_store = MemoryStore()
+        cache_store = MemoryStore()
+        cached_store = CacheStore(store=source_store, cache_store=cache_store)
+
+        proto = default_buffer_prototype()
+
+        # Populate and cache a byte range
+        await source_store.set("key", CPUBuffer.from_bytes(b"hello world"))
+        r = await cached_store.get("key", proto, byte_range=RangeByteRequest(0, 5))
+        assert r is not None
+        assert r.to_bytes() == b"hello"
+
+        assert ("key", RangeByteRequest(0, 5)) in cached_store._state.cache_order
+
+        # Delete the key — range entries must be cleaned up
+        await cached_store.delete("key")
+
+        assert ("key", RangeByteRequest(0, 5)) not in cached_store._state.cache_order
+        assert "key" not in cached_store._state.range_cache
+
+        # Key is gone from source
+        result = await cached_store.get("key", proto)
+        assert result is None
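
For reviewers who want to exercise the fixed read pattern outside the test suite, here is a minimal end-to-end sketch of the behaviour the new tests assert. It only uses APIs that appear in this patch, with one assumption: the `from zarr.storage import MemoryStore` import path is not shown in the diff, so adjust it to wherever MemoryStore lives in your checkout. This is an illustration of the intended behaviour, not part of the patch.

import asyncio

from zarr.abc.store import RangeByteRequest
from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.buffer.cpu import Buffer as CPUBuffer
from zarr.experimental.cache_store import CacheStore
from zarr.storage import MemoryStore  # assumed import path; not shown in this diff


async def main() -> None:
    source = MemoryStore()
    cache = MemoryStore()
    cached = CacheStore(store=source, cache_store=cache)

    # Write the full object to the source store only.
    await source.set("foo", CPUBuffer.from_bytes(b"bar baz"))
    proto = default_buffer_prototype()

    # Byte-range reads are served from (and stored in) the in-memory range
    # cache, so the partial bytes are never written under the full key.
    first = await cached.get("foo", proto, byte_range=RangeByteRequest(0, 3))
    assert first is not None and first.to_bytes() == b"bar"

    # A later full-key read must still return the complete object.
    full = await cached.get("foo", proto)
    assert full is not None and full.to_bytes() == b"bar baz"

    # Repeating the byte-range read now hits the range cache.
    again = await cached.get("foo", proto, byte_range=RangeByteRequest(0, 3))
    assert again is not None and again.to_bytes() == b"bar"

    print(cached.cache_stats())


asyncio.run(main())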