|
| 1 | +from pathlib import Path |
| 2 | +from typing import TYPE_CHECKING, Union |
| 3 | + |
| 4 | +import dask |
| 5 | +import dask.array |
| 6 | +import numpy as np |
| 7 | +import pytest |
| 8 | +import xarray as xr |
| 9 | +import zarr |
| 10 | +from datatree import DataTree |
| 11 | + |
| 12 | +from spatialdata import SpatialData |
| 13 | +from spatialdata._io import write_image |
| 14 | +from spatialdata._io.format import CurrentRasterFormat |
| 15 | +from spatialdata.models import Image2DModel |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + import _pytest.fixtures |
| 19 | + |
| 20 | + |
@pytest.fixture
def sdata_with_image(request: "_pytest.fixtures.SubRequest", tmp_path: Path) -> SpatialData:
    """Create an in-memory SpatialData whose scale-0 image is a disk-backed Dask array.

    Optional indirect parameters (passed via ``pytest.mark.parametrize(..., indirect=...)``):
        width: edge length in pixels of the square (c=1) image. Default 2048.
        chunk_size: Dask/zarr chunk edge length. Default 1024.
        scale_factors: per-level downscaling factors for the pyramid. Default (2,).
    """
    # `request.param` only exists when the fixture is parametrized indirectly;
    # fall back to defaults so the fixture also works without parametrization
    # (accessing the missing attribute would raise AttributeError otherwise).
    params = getattr(request, "param", None) or {}
    width = params.get("width", 2048)
    chunksize = params.get("chunk_size", 1024)
    scale_factors = params.get("scale_factors", (2,))
    # Create a disk-backed Dask array for scale 0 (seeded RNG for reproducibility).
    rng = np.random.default_rng(0)
    array = rng.integers(low=0, high=2**16, size=(1, width, width))
    array_path = tmp_path / "image.zarr"
    dask.array.from_array(array).rechunk(chunksize).to_zarr(array_path)
    array_backed = dask.array.from_zarr(array_path)
    # Create an in-memory SpatialData with the disk-backed array as scale 0;
    # the pyramid levels are derived lazily by the model.
    image = Image2DModel.parse(array_backed, dims=("c", "y", "x"), scale_factors=scale_factors, chunks=chunksize)
    return SpatialData(images={"image": image})
| 36 | + |
| 37 | + |
def count_chunks(array: Union[xr.DataArray, xr.Dataset, DataTree]) -> int:
    """Return the total number of Dask chunks of *array* across all dimensions.

    For a DataTree node, the contained dataset is inspected instead.
    """
    if isinstance(array, DataTree):
        array = array.ds
    # `chunksizes` maps each dimension to the tuple of chunk lengths along it,
    # so the length of each tuple is the chunk count along that axis; their
    # product is the total number of chunks in 2D/3D.
    # Cast to int: np.prod returns np.int64 — or float 1.0 when `chunksizes`
    # is empty (unchunked array) — which would violate the `-> int` annotation.
    return int(np.prod([len(chunk_sizes) for chunk_sizes in array.chunksizes.values()]))
| 44 | + |
| 45 | + |
@pytest.mark.parametrize(
    ("sdata_with_image",),
    [
        ({"width": 32, "chunk_size": 16, "scale_factors": (2,)},),
        ({"width": 64, "chunk_size": 16, "scale_factors": (2, 2)},),
        ({"width": 128, "chunk_size": 16, "scale_factors": (2, 2, 2)},),
        ({"width": 256, "chunk_size": 16, "scale_factors": (2, 2, 2, 2)},),
    ],
    indirect=["sdata_with_image"],
)
def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_path: Path, mocker):
    """Writing a multiscale image should not touch chunks more often than necessary.

    Rather than timing the write (noisy for small data), spy on zarr's
    chunk-level read/write entry points and compare the observed call counts
    against the theoretical minimum (see
    https://github.com/scverse/spatialdata/issues/577): every chunk of every
    pyramid level is written exactly once, and only scale-0 chunks are read.
    """
    write_spy = mocker.spy(zarr.core.Array, "__setitem__")
    read_spy = mocker.spy(zarr.core.Array, "__getitem__")

    name, image = next(iter(sdata_with_image.images.items()))
    images_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images")

    write_image(
        image=image,
        group=images_group,
        name=name,
        format=CurrentRasterFormat(),
    )

    if isinstance(image, DataTree):
        # Multiscale: only scale 0 is read; every pyramid level is written.
        expected_reads = count_chunks(image.scale0)
        expected_writes = sum(count_chunks(level) for level in image.children.values())
    else:
        # Single scale: the one array is both the read source and write target.
        expected_reads = count_chunks(image)
        expected_writes = expected_reads

    assert write_spy.call_count == expected_writes
    assert read_spy.call_count == expected_reads
0 commit comments