|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +import zarr |
| 6 | +from zarr.core.buffer.core import default_buffer_prototype |
| 7 | +from zarr.core.indexing import BasicIndexer |
| 8 | +from zarr.storage import MemoryStore |
| 9 | + |
| 10 | + |
| 11 | +@pytest.mark.parametrize( |
| 12 | + ("write_slice", "read_slice", "expected_statuses"), |
| 13 | + [ |
| 14 | + # Write all chunks, read all — all present |
| 15 | + (slice(None), slice(None), ("present", "present", "present")), |
| 16 | + # Write first chunk only, read all — first present, rest missing |
| 17 | + (slice(0, 2), slice(None), ("present", "missing", "missing")), |
| 18 | + # Write nothing, read all — all missing |
| 19 | + (None, slice(None), ("missing", "missing", "missing")), |
| 20 | + ], |
| 21 | +) |
| 22 | +async def test_read_returns_get_results( |
| 23 | + write_slice: slice | None, |
| 24 | + read_slice: slice, |
| 25 | + expected_statuses: tuple[str, ...], |
| 26 | +) -> None: |
| 27 | + """ |
| 28 | + Test that CodecPipeline.read returns a tuple of GetResult with correct statuses. |
| 29 | + """ |
| 30 | + store = MemoryStore() |
| 31 | + arr = zarr.open_array(store, mode="w", shape=(6,), chunks=(2,), dtype="int64", fill_value=-1) |
| 32 | + |
| 33 | + if write_slice is not None: |
| 34 | + arr[write_slice] = 0 |
| 35 | + |
| 36 | + async_arr = arr._async_array |
| 37 | + pipeline = async_arr.codec_pipeline |
| 38 | + metadata = async_arr.metadata |
| 39 | + |
| 40 | + prototype = default_buffer_prototype() |
| 41 | + config = async_arr.config |
| 42 | + indexer = BasicIndexer( |
| 43 | + read_slice, |
| 44 | + shape=metadata.shape, |
| 45 | + chunk_grid=metadata.chunk_grid, |
| 46 | + ) |
| 47 | + |
| 48 | + out_buffer = prototype.nd_buffer.empty( |
| 49 | + shape=indexer.shape, |
| 50 | + dtype=metadata.dtype.to_native_dtype(), |
| 51 | + order=config.order, |
| 52 | + ) |
| 53 | + |
| 54 | + results = await pipeline.read( |
| 55 | + [ |
| 56 | + ( |
| 57 | + async_arr.store_path / metadata.encode_chunk_key(chunk_coords), |
| 58 | + metadata.get_chunk_spec(chunk_coords, config, prototype=prototype), |
| 59 | + chunk_selection, |
| 60 | + out_selection, |
| 61 | + is_complete_chunk, |
| 62 | + ) |
| 63 | + for chunk_coords, chunk_selection, out_selection, is_complete_chunk in indexer |
| 64 | + ], |
| 65 | + out_buffer, |
| 66 | + drop_axes=indexer.drop_axes, |
| 67 | + ) |
| 68 | + |
| 69 | + assert len(results) == len(expected_statuses) |
| 70 | + for result, expected_status in zip(results, expected_statuses, strict=True): |
| 71 | + assert result["status"] == expected_status |
0 commit comments