diff --git a/changes/3907.feature.md b/changes/3907.feature.md new file mode 100644 index 0000000000..e9fa75d632 --- /dev/null +++ b/changes/3907.feature.md @@ -0,0 +1,5 @@ +Add the `SupportsSetRange` protocol for stores that support writing to a byte range within an existing value, implemented by `LocalStore` and `MemoryStore`. This is necessary to support in-place writes of sharded arrays (e.g. writing a single subchunk without rewriting the entire shard). + +Byte-range writes are exposed as an opt-in protocol rather than a method on the `Store` ABC. Only a few stores can perform them natively, and most cannot. A universal method with a read-modify-write fallback (as in the Rust `zarrs` crate) would let every store participate, but for the motivating use case that fallback would silently rewrite an entire shard, defeating the purpose. The opt-in protocol keeps the cost model honest and keeps `set_range` out of the signatures of stores that will never support it; any fallback strategy is left to the caller (the sharding codec). Stores satisfy the protocol structurally, so `GpuMemoryStore` (which has no use case for in-place GPU byte-range writes) disclaims it and is correctly reported as unsupported by `isinstance`. + +It is entirely the caller's responsibility to ensure consistency: concurrent writes to overlapping ranges are order-dependent, `set_range` racing against `set`/`delete` is a race, and writes are not guaranteed to be atomic with respect to a process crash. A write that does not fit within the existing value raises `ValueError` consistently across implementations. diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py index ab58acf59f..0384f391e8 100644 --- a/src/zarr/abc/store.py +++ b/src/zarr/abc/store.py @@ -23,6 +23,7 @@ "Store", "SupportsDeleteSync", "SupportsGetSync", + "SupportsSetRange", "SupportsSetSync", "SupportsSyncStore", "set_or_delete", @@ -787,6 +788,62 @@ async def delete(self) -> None: ... async def set_if_not_exists(self, default: Buffer) -> None: ... +# Design note: byte-range writes are exposed as an opt-in protocol rather than a +# method on the `Store` ABC. Only a few stores can do them natively (`LocalStore`, +# `MemoryStore`); most (cloud, zip, read-only) cannot. A universal `Store.set_range` +# with a read-modify-write fallback (as in the Rust `zarrs` crate's +# `WritableStorageTraits::set_partial` + `supports_set_partial`) would let every +# store participate, but for our motivating use case — writing one subchunk without +# rewriting the whole shard — that fallback is a footgun: it would silently rewrite +# an entire (possibly multi-GB) shard, defeating the purpose while appearing to +# succeed. An opt-in protocol instead keeps the cost model honest (a store either +# supports cheap ranged writes or doesn't advertise the capability at all) and keeps +# `set_range` out of the signatures of stores that will never support it. +# +# Stores satisfy this protocol *structurally* (by defining the methods), not by +# nominal inheritance, so a subclass can disclaim it by setting the methods to `None` +# (see `GpuMemoryStore`). Any read-modify-write fallback strategy belongs in the +# caller (the sharding codec), which already has to decide between in-place and +# buffer-and-rewrite — mirroring the zarrs layering (storage writes bytes, codec owns +# strategy) without making every store carry the method. If broad-backend partial +# encoding is wanted later, adding a `supports_set_range()` capability flag plus a +# codec-level fallback is an additive change that does not require retrofitting stores. +@runtime_checkable +class SupportsSetRange(Protocol): + """Protocol for stores that support writing to a byte range within an existing value. + + Overwrites `len(value)` bytes starting at byte offset `start` within the + existing stored value for `key`. The key must already exist and the write + must fit within the existing value (i.e., `start + len(value) <= len(existing)`); + a write that does not fit raises `ValueError`. + + Concurrency and atomicity + ------------------------- + **It is entirely the caller's responsibility to ensure consistency.** Any + coordination needed to keep stored values consistent must be arranged by the + caller. In particular: + + - Concurrent `set_range` calls that write to **disjoint** byte ranges of the + same key are safe. + - Concurrent `set_range` calls that write to **overlapping** ranges of the same + key have order-dependent, unspecified results. The caller must serialize them. + - A `set_range` racing against a `set` or `delete` on the same key is a race + condition, just as concurrent `set` calls are. The caller must serialize these. + - Writes are **not** guaranteed to be atomic with respect to a process crash: + a crash mid-write may leave the value partially updated. The caller is + responsible for any durability or recovery guarantees it requires. + + What an implementation does to honor (or fall short of) this contract — locking, + atomic replacement, and so on — is documented on the implementing store, not here. + The intended consumer (the sharding codec writing inner chunks of deterministic + compressed size) coordinates writes so that they target disjoint ranges. + """ + + async def set_range(self, key: str, value: Buffer, start: int) -> None: ... + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: ... + + @runtime_checkable class SupportsGetSync(Protocol): def get_sync( diff --git a/src/zarr/storage/_fsspec.py b/src/zarr/storage/_fsspec.py index 29201a6fee..0493e2cf74 100644 --- a/src/zarr/storage/_fsspec.py +++ b/src/zarr/storage/_fsspec.py @@ -266,7 +266,14 @@ def from_url( from fsspec.core import url_to_fs opts = storage_options or {} - opts = {"asynchronous": True, **opts} + # ``skip_instance_cache=True`` forces a fresh filesystem instance instead of + # an fsspec instance-cached one. Without it, two ``from_url`` calls with the + # same URL/options receive the *same* cached ``AsyncFileSystem``; closing one + # store (which we mark as owning the fs) would tear down the shared aiohttp + # session out from under the other store — and any other fsspec consumer in + # the process. By skipping the cache we own an instance no one else shares, so + # ``close()`` is safe. + opts = {"asynchronous": True, "skip_instance_cache": True, **opts} fs, path = url_to_fs(url, **opts) if not fs.async_impl: diff --git a/src/zarr/storage/_local.py b/src/zarr/storage/_local.py index 038de4fef8..3c802db9be 100644 --- a/src/zarr/storage/_local.py +++ b/src/zarr/storage/_local.py @@ -6,6 +6,8 @@ import os import shutil import sys +import threading +import time import uuid from pathlib import Path from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Self @@ -20,6 +22,7 @@ from zarr.core.buffer import Buffer from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import AccessModeLiteral, concurrent_map +from zarr.storage._utils import _check_set_range_bounds if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterable, Iterator @@ -58,6 +61,18 @@ def _safe_move(src: Path, dst: Path) -> None: os.unlink(src) +_LOCK_POLL_INTERVAL = 0.01 # seconds between lock-file existence checks +_LOCK_STALE_TIMEOUT = 60.0 # seconds before an abandoned lock file is reclaimed + + +def _is_stale_lock(lock_path: Path) -> bool: + """Return True if lock_path either doesn't exist or is older than _LOCK_STALE_TIMEOUT.""" + try: + return time.time() - lock_path.stat().st_mtime > _LOCK_STALE_TIMEOUT + except FileNotFoundError: + return True + + @contextlib.contextmanager def _atomic_write( path: Path, @@ -77,6 +92,20 @@ def _atomic_write( raise +def _put_range(path: Path, value: Buffer, start: int) -> None: + """Write bytes at a specific offset within an existing file.""" + view = value.as_buffer_like() + with path.open("r+b") as f: + # Validate bounds before writing: a bare seek+write would silently extend the + # file (zero-filling any gap), but the SupportsSetRange contract requires the + # write to fit within the existing value, so we fail consistently with + # MemoryStore instead. + existing_length = f.seek(0, os.SEEK_END) + _check_set_range_bounds(existing_length, start, len(value)) + f.seek(start) + f.write(view) + + def _put(path: Path, value: Buffer, exclusive: bool = False) -> int: path.parent.mkdir(parents=True, exist_ok=True) # write takes any object supporting the buffer protocol @@ -109,6 +138,8 @@ class LocalStore(Store): supports_listing: bool = True root: Path + _key_locks: dict[str, asyncio.Lock] + _key_locks_sync: dict[str, threading.Lock] def __init__(self, root: Path | str, *, read_only: bool = False) -> None: super().__init__(read_only=read_only) @@ -119,6 +150,8 @@ def __init__(self, root: Path | str, *, read_only: bool = False) -> None: f"'root' must be a string or Path instance. Got an instance of {type(root)} instead." ) self.root = root + self._key_locks = {} + self._key_locks_sync = {} def with_read_only(self, read_only: bool = False) -> Self: # docstring inherited @@ -292,6 +325,82 @@ async def _set(self, key: str, value: Buffer, exclusive: bool = False) -> None: path = self.root / key await asyncio.to_thread(_put, path, value, exclusive=exclusive) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + if not self._is_open: + await self._open() + self._check_writable() + path = self.root / key + lock_path = path.with_name(path.name + ".__lock__") + in_process_lock = self._key_locks.setdefault(key, asyncio.Lock()) + + # Acquire the file lock (steps 1-5 from the concurrency plan). + while True: + # Step 1: spin-wait until no lock file is present (or it is stale). + while await asyncio.to_thread(lock_path.exists): + if await asyncio.to_thread(_is_stale_lock, lock_path): + break + await asyncio.sleep(_LOCK_POLL_INTERVAL) + + # Steps 2-5: serialise the rename under an in-process lock so that + # only one coroutine per process attempts the atomic file move at a time. + acquired = False + async with in_process_lock: + # Step 3: re-check after acquiring the in-process lock. + if not await asyncio.to_thread(lock_path.exists): + try: + # Step 4: atomic rename — raises FileExistsError if another + # process grabbed the lock between steps 3 and 4. + await asyncio.to_thread(_safe_move, path, lock_path) + acquired = True + except FileExistsError: + pass + # Step 5: in-process lock released on context exit. + + if acquired: + break + + # Step 6: perform the partial write on the lock file. + try: + await asyncio.to_thread(_put_range, lock_path, value, start) + finally: + # Steps 7-9: re-acquire in-process lock, rename lock file back, release. + async with in_process_lock: + await asyncio.to_thread(lock_path.replace, path) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._ensure_open_sync() + self._check_writable() + path = self.root / key + lock_path = path.with_name(path.name + ".__lock__") + in_process_lock = self._key_locks_sync.setdefault(key, threading.Lock()) + + # Acquire the file lock (same double-checked pattern as the async path). + while True: + # Step 1: spin-wait. + while lock_path.exists(): + if _is_stale_lock(lock_path): + break + time.sleep(_LOCK_POLL_INTERVAL) + + acquired = False + with in_process_lock: + if not lock_path.exists(): + try: + _safe_move(path, lock_path) + acquired = True + except FileExistsError: + pass + + if acquired: + break + + # Partial write, then release. + try: + _put_range(lock_path, value, start) + finally: + with in_process_lock: + lock_path.replace(path) + async def delete(self, key: str) -> None: """ Remove a key from the store. diff --git a/src/zarr/storage/_memory.py b/src/zarr/storage/_memory.py index 121fcdab7f..763453de0a 100644 --- a/src/zarr/storage/_memory.py +++ b/src/zarr/storage/_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import os import threading import weakref @@ -11,6 +12,7 @@ from zarr.core.buffer.core import default_buffer_prototype from zarr.core.common import concurrent_map from zarr.storage._utils import ( + _check_set_range_bounds, _join_paths, _normalize_byte_range_index, normalize_path, @@ -49,6 +51,8 @@ class MemoryStore(Store): supports_listing: bool = True _store_dict: MutableMapping[str, Buffer] + _key_locks: dict[str, asyncio.Lock] + _key_locks_sync: dict[str, threading.Lock] def __init__( self, @@ -60,6 +64,8 @@ def __init__( if store_dict is None: store_dict = {} self._store_dict = store_dict + self._key_locks = {} + self._key_locks_sync = {} def with_read_only(self, read_only: bool = False) -> MemoryStore: # docstring inherited @@ -194,6 +200,31 @@ async def delete(self, key: str) -> None: except KeyError: logger.debug("Key %s does not exist.", key) + def _set_range_impl(self, key: str, value: Buffer, start: int) -> None: + buf = self._store_dict[key] + target = buf.as_numpy_array() + _check_set_range_bounds(len(target), start, len(value)) + if not target.flags.writeable: + target = target.copy() + self._store_dict[key] = buf.__class__(target) + source = value.as_numpy_array() + target[start : start + len(source)] = source + + async def set_range(self, key: str, value: Buffer, start: int) -> None: + self._check_writable() + await self._ensure_open() + lock = self._key_locks.setdefault(key, asyncio.Lock()) + async with lock: + self._set_range_impl(key, value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + self._check_writable() + if not self._is_open: + self._is_open = True + lock = self._key_locks_sync.setdefault(key, threading.Lock()) + with lock: + self._set_range_impl(key, value, start) + async def list(self) -> AsyncIterator[str]: # docstring inherited for key in self._store_dict: @@ -537,6 +568,19 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None gpu_value = value if isinstance(value, gpu.Buffer) else gpu.Buffer.from_buffer(value) await super().set(key, gpu_value, byte_range=byte_range) + # ``GpuMemoryStore`` deliberately does not support byte-range writes, so it must + # not satisfy ``SupportsSetRange``. The inherited ``MemoryStore`` implementation + # mutates a *host* copy of the GPU buffer (via ``as_numpy_array``) and would + # silently lose the write, and there is no use case for in-place byte-range writes + # into GPU memory (the intended ``set_range`` consumer targets cpu/local storage). + # Disclaiming the inherited methods by setting them to ``None`` makes + # ``isinstance(gpu_store, SupportsSetRange)`` return ``False`` (``runtime_checkable`` + # treats a ``None`` attribute as "method absent"). This works because + # ``MemoryStore`` satisfies the protocol structurally rather than by nominal + # inheritance; mirrors the ``__hash__ = None`` idiom. + set_range = None # type: ignore[assignment] + set_range_sync = None # type: ignore[assignment] + # ----------------------------------------------------------------------------- # ManagedMemoryStore and its registry @@ -570,6 +614,13 @@ def __init__(self) -> None: self._registry: weakref.WeakValueDictionary[str, _ManagedStoreDict] = ( weakref.WeakValueDictionary() ) + # set_range per-key lock dicts, shared by name so every ManagedMemoryStore + # bound to the same backing dict serializes set_range against the others (the + # registry shares the data, so it must share the locks too). Keyed by name + # (_ManagedStoreDict is an unhashable dict subclass) and pruned when the + # corresponding dict has been collected. + self._key_locks: dict[str, dict[str, asyncio.Lock]] = {} + self._key_locks_sync: dict[str, dict[str, threading.Lock]] = {} self._counter = 0 self._lock = threading.Lock() @@ -639,6 +690,24 @@ def get(self, name: str) -> _ManagedStoreDict | None: """ return self._registry.get(name) + def get_key_locks(self, name: str) -> tuple[dict[str, asyncio.Lock], dict[str, threading.Lock]]: + """ + Get the shared set_range per-key lock dicts for a store name. + + All ManagedMemoryStore instances resolving to the same name get the same lock + dicts, so their set_range calls serialize against each other. Created on first + use; stale entries (whose backing dict has been collected) are pruned. + """ + with self._lock: + stale = [n for n in self._key_locks if n not in self._registry] + for n in stale: + del self._key_locks[n] + self._key_locks_sync.pop(n, None) + return ( + self._key_locks.setdefault(name, {}), + self._key_locks_sync.setdefault(name, {}), + ) + _managed_store_dict_registry = _ManagedStoreDictRegistry() @@ -709,6 +778,12 @@ def __init__(self, name: str | None = None, *, path: str = "", read_only: bool = # Get or create a managed dict from the registry self._store_dict, self._name = _managed_store_dict_registry.get_or_create(name) self.path = normalize_path(path) + # Share the per-key set_range locks with every other store backed by the same + # dict, so concurrent set_range from different handles to the same name actually + # serialize. + self._key_locks, self._key_locks_sync = _managed_store_dict_registry.get_key_locks( + self._name + ) def __str__(self) -> str: return _join_paths([f"memory://{self._name}", self.path]) @@ -744,6 +819,7 @@ def _from_managed_dict( store._store_dict = managed_dict store._name = name store.path = normalize_path(path) + store._key_locks, store._key_locks_sync = _managed_store_dict_registry.get_key_locks(name) return store def with_read_only(self, read_only: bool = False) -> ManagedMemoryStore: @@ -827,6 +903,14 @@ async def set_if_not_exists(self, key: str, value: Buffer) -> None: # docstring inherited return await super().set_if_not_exists(_join_paths([self.path, key]), value) + async def set_range(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + return await super().set_range(_join_paths([self.path, key]), value, start) + + def set_range_sync(self, key: str, value: Buffer, start: int) -> None: + # docstring inherited + return super().set_range_sync(_join_paths([self.path, key]), value, start) + async def delete(self, key: str) -> None: # docstring inherited return await super().delete(_join_paths([self.path, key])) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 1f8e9b0a29..ef7bdf296e 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -160,6 +160,23 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> return (start, stop) +def _check_set_range_bounds(existing_length: int, start: int, value_length: int) -> None: + """ + Validate that a ``set_range`` write fits within an existing value. + + Stores implementing ``SupportsSetRange`` use this so the out-of-bounds case fails + the same way everywhere (a clear ``ValueError``) rather than silently extending the + value (as a file seek+write would) or raising an opaque numpy shape error. + """ + if start < 0: + raise ValueError(f"set_range start must be non-negative, got {start}.") + if start + value_length > existing_length: + raise ValueError( + f"set_range write of {value_length} bytes at offset {start} does not fit " + f"within the existing value of length {existing_length}." + ) + + def _join_paths(paths: Iterable[str]) -> str: """ Filter out instances of '' and join the remaining strings with '/'. diff --git a/tests/test_store/test_fsspec.py b/tests/test_store/test_fsspec.py index 8006470174..fa8426448d 100644 --- a/tests/test_store/test_fsspec.py +++ b/tests/test_store/test_fsspec.py @@ -313,6 +313,28 @@ async def test_from_url_close_releases_store(self) -> None: assert not store._is_open + async def test_from_url_uses_distinct_filesystem_instances(self) -> None: + """Two from_url() calls for the same URL must not share a cached fs. + + Regression: from_url claims ownership and closes the fs on close(); if it used + the fsspec instance cache, two stores would share one fs and closing one would + tear the shared session out from under the other. skip_instance_cache=True must + give each store its own fs. + """ + url = f"s3://{test_bucket_name}/distinct/" + opts = {"endpoint_url": endpoint_url, "anon": False} + store_a = FsspecStore.from_url(url, storage_options=opts) + store_b = FsspecStore.from_url(url, storage_options=opts) + assert store_a.fs is not store_b.fs + # Closing one leaves the other fully usable. + await store_a.set("probe", cpu.Buffer.from_bytes(b"x")) + store_a.close() + await store_b.set("probe", cpu.Buffer.from_bytes(b"y")) + result = await store_b.get("probe", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == b"y" + store_b.close() + def test_direct_construction_does_not_own_filesystem(self) -> None: """Direct FsspecStore() must not claim ownership — the caller owns the fs.""" try: diff --git a/tests/test_store/test_local.py b/tests/test_store/test_local.py index bdc9b48121..0661208102 100644 --- a/tests/test_store/test_local.py +++ b/tests/test_store/test_local.py @@ -1,8 +1,10 @@ from __future__ import annotations +import asyncio import json import pathlib import re +import threading from typing import TYPE_CHECKING import numpy as np @@ -10,6 +12,7 @@ import zarr from zarr import create_array +from zarr.abc.store import SupportsSetRange from zarr.core.buffer import Buffer, cpu from zarr.core.sync import sync from zarr.storage import LocalStore @@ -162,6 +165,142 @@ def test_get_json_sync_with_prototype_none( result = store._get_json_sync(key, prototype=buffer_cls) assert result == data + def test_supports_set_range(self, store: LocalStore) -> None: + """LocalStore should implement SupportsSetRange.""" + assert isinstance(store, SupportsSetRange) + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + async def test_set_range( + self, store: LocalStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range should overwrite bytes at the given offset.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + def test_set_range_sync( + self, store: LocalStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range_sync should overwrite bytes at the given offset.""" + sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch"), + [(9, b"XX"), (10, b"X"), (0, b"ZZZZZZZZZZZ")], + ids=["overhang", "past-end", "too-long"], + ) + async def test_set_range_out_of_bounds( + self, store: LocalStore, start: int, patch: bytes + ) -> None: + """A write that does not fit within the existing value raises, not extends.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + with pytest.raises(ValueError, match="does not fit within the existing value"): + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + # The file is left unchanged (not zero-filled / extended). + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == b"AAAAAAAAAA" + + async def test_set_range_not_open(self, store_not_open: LocalStore) -> None: + """set_range auto-opens a closed store.""" + assert not store_not_open._is_open + await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert getattr(store_not_open, "_is_open") # noqa: B009 + observed = await self.get(store_not_open, "test/key") + assert observed.to_bytes() == b"XXAAAAAAAA" + + def test_set_range_sync_not_open(self, store_not_open: LocalStore) -> None: + """set_range_sync auto-opens a closed store.""" + assert not store_not_open._is_open + sync(self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert getattr(store_not_open, "_is_open") # noqa: B009 + observed = sync(self.get(store_not_open, "test/key")) + assert observed.to_bytes() == b"XXAAAAAAAA" + + async def test_set_range_concurrent(self, store: LocalStore) -> None: + """Concurrent set_range calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + await store.set("test/key", cpu.Buffer.from_bytes(bytes(total))) + + async def write_chunk(i: int) -> None: + data = bytes([i] * chunk_size) + await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + + await asyncio.gather(*[write_chunk(i) for i in range(n_writers)]) + + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_set_range_sync_concurrent(self, store: LocalStore) -> None: + """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + sync(store.set("test/key", cpu.Buffer.from_bytes(bytes(total)))) + + errors: list[Exception] = [] + + def write_chunk(i: int) -> None: + try: + data = bytes([i] * chunk_size) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_lock_file_cleaned_up(self, store: LocalStore) -> None: + """No lock file should remain after set_range_sync completes.""" + sync(store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA"))) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + lock_path = store.root / "test" / "key.__lock__" + assert not lock_path.exists() + @pytest.mark.parametrize("exclusive", [True, False]) def test_atomic_write_successful(tmp_path: pathlib.Path, exclusive: bool) -> None: diff --git a/tests/test_store/test_memory.py b/tests/test_store/test_memory.py index 92e292bef1..c49d2abb36 100644 --- a/tests/test_store/test_memory.py +++ b/tests/test_store/test_memory.py @@ -1,7 +1,9 @@ from __future__ import annotations +import asyncio import json import re +import threading from typing import TYPE_CHECKING, Any import numpy as np @@ -9,6 +11,7 @@ import pytest import zarr +from zarr.abc.store import SupportsSetRange from zarr.core.buffer import Buffer, cpu, gpu from zarr.core.sync import sync from zarr.errors import ZarrUserWarning @@ -127,6 +130,131 @@ def test_get_json_sync_with_prototype_none( result = store._get_json_sync(key, prototype=buffer_cls) assert result == data + def test_supports_set_range(self, store: MemoryStore) -> None: + """MemoryStore should implement SupportsSetRange.""" + assert isinstance(store, SupportsSetRange) + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + async def test_set_range( + self, store: MemoryStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range should overwrite bytes at the given offset.""" + await store.set("test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch", "expected"), + [ + (0, b"XX", b"XXAAAAAAAA"), + (3, b"XX", b"AAAXXAAAAA"), + (8, b"XX", b"AAAAAAAAXX"), + (0, b"ZZZZZZZZZZ", b"ZZZZZZZZZZ"), + (5, b"B", b"AAAAABAAAA"), + (0, b"BCDE", b"BCDEAAAAAA"), + ], + ids=["start", "middle", "end", "full-overwrite", "single-byte", "multi-byte-start"], + ) + def test_set_range_sync( + self, store: MemoryStore, start: int, patch: bytes, expected: bytes + ) -> None: + """set_range_sync should overwrite bytes at the given offset.""" + store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + store.set_range_sync("test/key", cpu.Buffer.from_bytes(patch), start=start) + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + assert result.to_bytes() == expected + + @pytest.mark.parametrize( + ("start", "patch"), + [(9, b"XX"), (10, b"X"), (0, b"ZZZZZZZZZZZ")], + ids=["overhang", "past-end", "too-long"], + ) + async def test_set_range_out_of_bounds( + self, store: MemoryStore, start: int, patch: bytes + ) -> None: + """A write that does not fit within the existing value raises consistently.""" + store._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + with pytest.raises(ValueError, match="does not fit within the existing value"): + await store.set_range("test/key", cpu.Buffer.from_bytes(patch), start=start) + assert store._store_dict["test/key"].to_bytes() == b"AAAAAAAAAA" + + async def test_set_range_not_open(self, store_not_open: MemoryStore) -> None: + """set_range auto-opens a closed store.""" + assert not store_not_open._is_open + await self.set(store_not_open, "test/key", cpu.Buffer.from_bytes(b"AAAAAAAAAA")) + await store_not_open.set_range("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert getattr(store_not_open, "_is_open") # noqa: B009 + observed = await self.get(store_not_open, "test/key") + assert observed.to_bytes() == b"XXAAAAAAAA" + + def test_set_range_sync_not_open(self, store_not_open: MemoryStore) -> None: + """set_range_sync auto-opens a closed store.""" + assert not store_not_open._is_open + store_not_open._store_dict["test/key"] = cpu.Buffer.from_bytes(b"AAAAAAAAAA") + store_not_open.set_range_sync("test/key", cpu.Buffer.from_bytes(b"XX"), start=0) + assert getattr(store_not_open, "_is_open") # noqa: B009 + assert store_not_open._store_dict["test/key"].to_bytes() == b"XXAAAAAAAA" + + async def test_set_range_concurrent(self, store: MemoryStore) -> None: + """Concurrent set_range calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + await store.set("test/key", cpu.Buffer.from_bytes(bytes(total))) + + async def write_chunk(i: int) -> None: + data = bytes([i] * chunk_size) + await store.set_range("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + + await asyncio.gather(*[write_chunk(i) for i in range(n_writers)]) + + result = await store.get("test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + + def test_set_range_sync_concurrent(self, store: MemoryStore) -> None: + """Concurrent set_range_sync calls to non-overlapping ranges should not corrupt data.""" + n_writers = 10 + chunk_size = 10 + total = n_writers * chunk_size + store._store_dict["test/key"] = cpu.Buffer.from_bytes(bytes(total)) + + errors: list[Exception] = [] + + def write_chunk(i: int) -> None: + try: + data = bytes([i] * chunk_size) + store.set_range_sync("test/key", cpu.Buffer.from_bytes(data), start=i * chunk_size) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=write_chunk, args=(i,)) for i in range(n_writers)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + result = store.get_sync(key="test/key", prototype=cpu.buffer_prototype) + assert result is not None + expected = bytes([i for i in range(n_writers) for _ in range(chunk_size)]) + assert result.to_bytes() == expected + # TODO: fix this warning @pytest.mark.filterwarnings("ignore:Unclosed client session:ResourceWarning") @@ -182,6 +310,18 @@ def test_from_dict(self) -> None: for v in result._store_dict.values(): assert type(v) is gpu.Buffer + def test_set_range_not_supported(self, store: GpuMemoryStore) -> None: + """GpuMemoryStore deliberately does not satisfy SupportsSetRange. + + Capability detection via isinstance must report False so a consumer (e.g. the + sharding codec) does not select it and crash. The methods are disclaimed + (set to None), so isinstance returns False rather than a false positive. + """ + # mypy statically knows GpuMemoryStore cannot satisfy the protocol (the methods + # are None), which is exactly what we want — but it then flags this runtime + # assertion as unreachable. Keep the runtime check as a regression guard. + assert not isinstance(store, SupportsSetRange) # type: ignore[unreachable] + class TestManagedMemoryStore(StoreTests[ManagedMemoryStore, cpu.Buffer]): store_cls = ManagedMemoryStore @@ -451,6 +591,42 @@ async def test_path_prefix_operations(self) -> None: assert result2 is not None assert result2.to_bytes() == b"value" + def test_supports_set_range(self, store: ManagedMemoryStore) -> None: + assert isinstance(store, SupportsSetRange) + + async def test_set_range_applies_path_prefix(self) -> None: + """set_range must prepend the store's path prefix, matching set/get. + + Regression: an unprefixed inherited set_range would target the wrong key. + """ + store = ManagedMemoryStore(name="set-range-path-test", path="subdir") + await store.set("k", self.buffer_cls.from_bytes(b"AAAAAAAAAA")) + # set() writes to the prefixed backing key. + assert "subdir/k" in store._store_dict + await store.set_range("k", self.buffer_cls.from_bytes(b"XX"), start=2) + store.set_range_sync("k", self.buffer_cls.from_bytes(b"YY"), start=6) + # Both writes landed on the same prefixed value that set/get use. + observed = await store.get("k") + assert observed is not None + assert observed.to_bytes() == b"AAXXAAYYAA" + assert store._store_dict["subdir/k"].to_bytes() == b"AAXXAAYYAA" + + def test_set_range_locks_shared_by_name(self) -> None: + """Instances sharing a backing dict (same name) share the set_range lock dicts. + + The registry shares the data across same-name handles, so it must share the + locks too — otherwise concurrent set_range from two handles would not serialize. + """ + a = ManagedMemoryStore(name="lock-share-test") + b = ManagedMemoryStore.from_url("memory://lock-share-test") + c = a.with_read_only(not a.read_only) + assert a._key_locks is b._key_locks + assert a._key_locks_sync is b._key_locks_sync + assert a._key_locks is c._key_locks + # A differently named store has independent locks. + other = ManagedMemoryStore(name="lock-share-test-other") + assert other._key_locks is not a._key_locks + async def test_path_list_operations(self) -> None: """Test that list operations filter by path prefix.""" store = ManagedMemoryStore(name="list-test")