Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/4054.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Hypothesis property tests for block and mask indexing (`test_block_indexing`, `test_mask_indexing`), along with a `block_indices` strategy in `zarr.testing.strategies`. These extend the existing randomized indexing coverage (basic, orthogonal, and vectorized) to the block and mask selection methods.
61 changes: 61 additions & 0 deletions src/zarr/testing/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,67 @@ def orthogonal_indices(
return tuple(zindexer), tuple(np.broadcast_arrays(*npindexer))


@st.composite
def block_indices(
    draw: st.DrawFn, *, chunk_grid_shape: tuple[int, ...], chunks: tuple[int, ...]
) -> tuple[tuple[int | slice, ...], tuple[slice, ...]]:
    """
    Draw a block-selection indexer and the array-space selection it maps to.

    A block selection is a basic selection over the grid of chunks rather than
    over array elements. For every axis one entry is drawn with
    ``basic_indices`` over a 1-d shape equal to that axis's chunk count (taken
    from ``chunk_grid_shape``, e.g. ``Array.cdata_shape``), the same per-axis
    reuse of ``basic_indices`` that ``orthogonal_indices`` relies on. Draws that
    block indexing cannot express are rejected: slices with a step other than
    ``None``/1, and slices whose (normalized) start does not name an existing
    chunk.

    The array-space translation multiplies block bounds by the chunk length, so
    it is only valid for a *regular* (uniform) chunk grid. When the last chunk
    is partial, the translated stop may run past the end of the array; this is
    intentionally left unclamped because numpy clamps it when the oracle
    selection is applied.

    Parameters
    ----------
    chunk_grid_shape
        Number of chunks along each axis.
    chunks
        Chunk length along each axis; must have the same length as
        ``chunk_grid_shape``.

    Returns
    -------
    block_indexer
        One int or step-1 slice per axis, addressing whole chunks; suitable for
        ``Array.blocks`` / ``get_block_selection`` / ``set_block_selection``.
    array_indexer
        One slice per axis selecting the same region in element coordinates,
        for indexing the reference numpy array.
    """

    def supported(nchunks: int) -> Callable[[tuple[Any, ...]], bool]:
        # Build the acceptance test for one axis holding ``nchunks`` chunks.
        def predicate(value: tuple[Any, ...]) -> bool:
            candidate = value[0]
            # Integers are always accepted; only slices need vetting.
            if not isinstance(candidate, slice):
                return True
            # Strided block slices are not supported.
            if candidate.step not in (None, 1):
                return False
            # Unlike numpy, a slice starting at ``nchunks`` (or before the
            # first chunk after wrap-around) raises instead of being empty.
            first = candidate.start or 0
            if first < 0:
                first += nchunks
            return 0 <= first < nchunks

        return predicate

    def as_one_tuple(index: Any) -> Any:
        # ``basic_indices`` may yield a bare int / slice; wrap those so every
        # draw can be unpacked uniformly as a tuple.
        return index if isinstance(index, tuple) else (index,)

    block_indexer: list[int | slice] = []
    array_indexer: list[slice] = []
    for chunk_len, n_blocks in zip(chunks, chunk_grid_shape, strict=True):
        per_axis = (
            basic_indices(min_dims=1, shape=(n_blocks,), allow_ellipsis=False)
            .map(as_one_tuple)
            # drop the empty tuple before the predicate looks at element 0
            .filter(bool)
            .filter(supported(n_blocks))
        )
        (picked,) = draw(per_axis)

        # Resolve the selection to a half-open [first, last) range of blocks.
        if isinstance(picked, slice):
            first, last, _ = picked.indices(n_blocks)
        else:
            first = picked % n_blocks  # wrap negative block numbers
            last = first + 1

        block_indexer.append(picked)
        array_indexer.append(slice(first * chunk_len, last * chunk_len))
    return tuple(block_indexer), tuple(array_indexer)


def key_ranges(
keys: SearchStrategy[str] = node_names, max_size: int = sys.maxsize
) -> SearchStrategy[list[tuple[str, RangeByteRequest]]]:
Expand Down
61 changes: 61 additions & 0 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
array_metadata,
arrays,
basic_indices,
block_indices,
complex_rectilinear_arrays,
np_array_and_chunks,
numpy_arrays,
orthogonal_indices,
rectilinear_arrays,
Expand Down Expand Up @@ -230,6 +232,65 @@ async def test_vindex(data: st.DataObject) -> None:
# note: async vindex setitem not yet implemented


@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_mask_indexing(data: st.DataObject) -> None:
    """
    Property test: boolean-mask reads and writes on a zarr array match numpy.

    A zarr array (regular or rectilinear chunking) and a boolean mask of the
    same shape are drawn; the array's full contents, read back as a numpy
    array, serve as the oracle. Reads are checked through both
    ``get_mask_selection`` and ``vindex``; writes are checked through both
    ``set_mask_selection`` and ``vindex``.
    """
    zarray = data.draw(st.one_of(simple_arrays(), rectilinear_arrays()))
    nparray = zarray[:]
    mask = data.draw(npst.arrays(dtype=np.bool_, shape=st.just(nparray.shape)))

    expected = nparray[mask]

    # sync get, via both the dedicated method and the vindex interface
    assert_array_equal(expected, zarray.get_mask_selection(mask))
    assert_array_equal(expected, zarray.vindex[mask])

    # sync set, via both interfaces
    assume(zarray.shards is None)  # GH2834
    payloads = numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype)

    new_data = data.draw(payloads)
    nparray[mask] = new_data
    zarray.set_mask_selection(mask, new_data)
    assert_array_equal(nparray, zarray[:])

    # Draw an independent payload for the second setter: re-writing the values
    # that are already stored would pass even if ``vindex`` assignment did
    # nothing at all.
    other_data = data.draw(payloads)
    nparray[mask] = other_data
    zarray.vindex[mask] = other_data
    assert_array_equal(nparray, zarray[:])


@settings(deadline=None)
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@given(data=st.data())
def test_block_indexing(data: st.DataObject) -> None:
    """
    Property test: block (whole-chunk) reads and writes match numpy slicing.

    The array is created directly from a drawn numpy array and a regular
    chunking, then a block indexer and its array-space equivalent are drawn
    with ``block_indices``. Reads are checked through ``.blocks`` and
    ``get_block_selection``; writes through ``.blocks`` and
    ``set_block_selection``.
    """
    # Block indexing addresses whole chunks on a regular grid; the array-space
    # oracle in block_indices() assumes regular, unsharded chunks, so build the
    # array directly from a regular chunking rather than drawing one that might
    # be rectilinear or sharded.
    nparray, chunks = data.draw(
        np_array_and_chunks(arrays=numpy_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
    )
    store = data.draw(stores)
    zarray = zarr.create_array(store=store, shape=nparray.shape, chunks=chunks, dtype=nparray.dtype)
    zarray[...] = nparray

    block_indexer, array_indexer = data.draw(
        block_indices(chunk_grid_shape=zarray.cdata_shape, chunks=chunks)
    )
    expected = nparray[array_indexer]

    # sync get, via both the .blocks interface and the dedicated method
    assert_array_equal(expected, zarray.blocks[block_indexer])
    assert_array_equal(expected, zarray.get_block_selection(block_indexer))

    # sync set, via both interfaces
    payloads = numpy_arrays(shapes=st.just(expected.shape), dtype=nparray.dtype)

    new_data = data.draw(payloads)
    nparray[array_indexer] = new_data
    zarray.blocks[block_indexer] = new_data
    assert_array_equal(nparray, zarray[:])

    # Use a freshly drawn payload for the second setter: writing the same
    # values again would pass even if ``set_block_selection`` were a no-op.
    other_data = data.draw(payloads)
    nparray[array_indexer] = other_data
    zarray.set_block_selection(block_indexer, other_data)
    assert_array_equal(nparray, zarray[:])


@given(store=stores, meta=array_metadata()) # type: ignore[misc]
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
async def test_roundtrip_array_metadata_from_store(
Expand Down
Loading