diff --git a/changes/4001.misc.md b/changes/4001.misc.md new file mode 100644 index 0000000000..e90f16a9e8 --- /dev/null +++ b/changes/4001.misc.md @@ -0,0 +1,4 @@ +Restore sharding write performance for shards with many chunks. The +`subchunk_write_order` feature inadvertently rebuilt the per-shard chunk +coordinate grid (up to tens of thousands of tuples) on every partial write; +these coordinates are now cached, restoring throughput to its previous level. diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 33c8602ecb..bcc783f386 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -46,8 +46,11 @@ BasicIndexer, ChunkProjection, SelectorTuple, + _lexicographic_order, + _lexicographic_order_keys, c_order_iter, get_indexer, + lexicographic_order_iter, morton_order_iter, ) from zarr.core.metadata.v3 import ( @@ -263,29 +266,30 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[tuple[int, ...]]: return c_order_iter(self.index.chunks_per_shard) - def to_dict_vectorized( - self, - chunk_coords_array: npt.NDArray[np.integer[Any]], - ) -> dict[tuple[int, ...], Buffer | None]: + def to_dict_vectorized(self) -> dict[tuple[int, ...], Buffer | None]: """Build a dict of chunk coordinates to buffers using vectorized lookup. - Parameters - ---------- - chunk_coords_array : ndarray of shape (n_chunks, n_dims) - Array of chunk coordinates for vectorized index lookup. + The full per-shard chunk coordinate grid (both the array used for the + vectorized index lookup and the plain tuples used as dict keys) is + cached on `chunks_per_shard`, so neither is rebuilt on every call. For a + shard with tens of thousands of chunks this avoids reconstructing that + many tuples on every partial write. Returns ------- dict mapping chunk coordinate tuples to Buffer or None """ + chunks_per_shard = self.index.chunks_per_shard + chunk_coords_array = _lexicographic_order(chunks_per_shard) + chunk_coords_keys = _lexicographic_order_keys(chunks_per_shard) starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array) result: dict[tuple[int, ...], Buffer | None] = {} - for i, coords in enumerate(chunk_coords_array): + for i, coords in enumerate(chunk_coords_keys): if valid[i]: - result[tuple(coords.ravel())] = self.buf[int(starts[i]) : int(ends[i])] + result[coords] = self.buf[int(starts[i]) : int(ends[i])] else: - result[tuple(coords.ravel())] = None + result[coords] = None return result @@ -533,7 +537,7 @@ def _subchunk_order_iter( case "morton": subchunk_iter = morton_order_iter(chunks_per_shard) case "lexicographic": - subchunk_iter = np.ndindex(chunks_per_shard) + subchunk_iter = lexicographic_order_iter(chunks_per_shard) case "colexicographic": subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1])) case "unordered": @@ -612,10 +616,10 @@ async def _encode_partial_single( chunks_per_shard=chunks_per_shard, ) shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) - # Use vectorized lookup for better performance - shard_dict = shard_reader.to_dict_vectorized( - np.array(list(self._subchunk_order_iter(chunks_per_shard, "lexicographic"))) - ) + # Use vectorized lookup for better performance. The lexicographic + # coordinate array and keys are cached, so neither is rebuilt on + # every write. + shard_dict = shard_reader.to_dict_vectorized() await self.codec_pipeline.write( [ diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index cb81164209..ab658a4924 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -1584,6 +1584,33 @@ def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]] return iter(_morton_order_keys(tuple(chunk_shape))) +@lru_cache(maxsize=16) +def _lexicographic_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]: + # Lexicographic (C-order) coordinates, computed vectorized and cached so that + # the sharding codec's per-shard chunk grid is not rebuilt on every call. + # Equivalent to `np.array(list(np.ndindex(chunk_shape)))` but without the + # Python-level iteration over every coordinate. + n_dims = len(chunk_shape) + if n_dims == 0: + # A 0-d shard holds a single chunk addressed by the empty coordinate, so + # the coordinate array has one row and zero columns. np.indices(()) cannot + # express this, so build it directly. Matches list(np.ndindex(())) == [()]. + order = np.empty((1, 0), dtype=np.intp) + else: + order = np.indices(chunk_shape, dtype=np.intp).reshape(n_dims, -1).T + order.flags.writeable = False + return order + + +@lru_cache(maxsize=16) +def _lexicographic_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]: + return tuple(tuple(int(x) for x in row) for row in _lexicographic_order(chunk_shape)) + + +def lexicographic_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]: + return iter(_lexicographic_order_keys(tuple(chunk_shape))) + + def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]: return itertools.product(*(range(x) for x in chunks_per_shard)) diff --git a/tests/test_codecs/test_sharding.py b/tests/test_codecs/test_sharding.py index 74e4a7e0d5..811de031e8 100644 --- a/tests/test_codecs/test_sharding.py +++ b/tests/test_codecs/test_sharding.py @@ -992,3 +992,39 @@ def test_shard_index_get_chunk_slices_vectorized(chunks_per_shard: tuple[int, .. assert starts[0] == 10 assert ends[0] == 14 np.testing.assert_array_equal(starts[~expected_valid], MAX_UINT_64) + + +@pytest.mark.parametrize("chunks_per_shard", [(), (3,), (2, 3)]) +def test_shard_reader_to_dict_vectorized(chunks_per_shard: tuple[int, ...]) -> None: + """to_dict_vectorized derives its own coords and maps present chunks to buffers, empty to None. + + The reader is given the full per-shard chunk grid implicitly (it reads + ``chunks_per_shard`` off its own index), so the result must contain every + lexicographic coordinate as a key, with the stored bytes for present chunks + and ``None`` for empty ones. + """ + all_coords = list(c_order_iter(chunks_per_shard)) + # Lay two chunks back-to-back in the buffer; leave the rest (if any) empty. + payload = b"abcdXY" + index = _ShardIndex.create_empty(chunks_per_shard) + index.set_chunk_slice(all_coords[0], slice(0, 4)) + present = {all_coords[0]: payload[0:4]} + if len(all_coords) > 1: + index.set_chunk_slice(all_coords[1], slice(4, 6)) + present[all_coords[1]] = payload[4:6] + + reader = _ShardReader() + reader.index = index + reader.buf = default_buffer_prototype().buffer.from_bytes(payload) + + result = reader.to_dict_vectorized() + + # Every lexicographic coordinate is present as a key, in order. + assert list(result.keys()) == all_coords + for coords in all_coords: + buf = result[coords] + if coords in present: + assert buf is not None + assert buf.to_bytes() == present[coords] + else: + assert buf is None