Skip to content

Commit a126e58

Browse files
authored
Improve performance of writing pyramids (scverse#726)
* Add performance test for issue 577
* Compute all pyramid levels at once
* Add change log entry
1 parent c09f35a commit a126e58

4 files changed

Lines changed: 94 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning][].
1515
- Added `shortest_path` parameter to `get_transformation_between_coordinate_systems`
1616
- Added `get_pyramid_levels()` utils API
1717
- Improved ergonomics of `concatenate()` when element names are non-unique #720
18+
- Improved performance of writing images with multiscales #577
1819

1920
## [0.2.3] - 2024-09-25
2021

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ docs = [
6868
test = [
6969
"pytest",
7070
"pytest-cov",
71+
"pytest-mock",
7172
]
7273
torch = [
7374
"torch"

src/spatialdata/_io/io_raster.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pathlib import Path
22
from typing import Any, Literal, Optional, Union
33

4+
import dask.array as da
45
import numpy as np
56
import zarr
67
from datatree import DataTree
@@ -195,15 +196,18 @@ def _get_group_for_writing_transformations() -> zarr.Group:
195196
# coords = iterate_pyramid_levels(raster_data, "coords")
196197
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=format)
197198
storage_options = [{"chunks": chunk} for chunk in chunks]
198-
write_multi_scale_ngff(
199+
dask_delayed = write_multi_scale_ngff(
199200
pyramid=data,
200201
group=group_data,
201202
fmt=format,
202203
axes=parsed_axes,
203204
coordinate_transformations=None,
204205
storage_options=storage_options,
205206
**metadata,
207+
compute=False,
206208
)
209+
# Compute all pyramid levels at once to allow Dask to optimize the computational graph.
210+
da.compute(*dask_delayed)
207211
assert transformations is not None
208212
overwrite_coordinate_transformations_raster(
209213
group=_get_group_for_writing_transformations(), transformations=transformations, axes=tuple(input_axes)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from pathlib import Path
2+
from typing import TYPE_CHECKING, Union
3+
4+
import dask
5+
import dask.array
6+
import numpy as np
7+
import pytest
8+
import xarray as xr
9+
import zarr
10+
from datatree import DataTree
11+
12+
from spatialdata import SpatialData
13+
from spatialdata._io import write_image
14+
from spatialdata._io.format import CurrentRasterFormat
15+
from spatialdata.models import Image2DModel
16+
17+
if TYPE_CHECKING:
18+
import _pytest.fixtures
19+
20+
21+
@pytest.fixture
def sdata_with_image(request: "_pytest.fixtures.SubRequest", tmp_path: Path) -> SpatialData:
    """Build an in-memory SpatialData whose single 2D image has a disk-backed scale 0.

    Indirect parametrization may supply a dict with ``width``, ``chunk_size`` and
    ``scale_factors``; unspecified keys fall back to defaults.
    """
    options = {} if request.param is None else request.param
    side = options.get("width", 2048)
    chunk = options.get("chunk_size", 1024)
    factors = options.get("scale_factors", (2,))

    # Write a deterministic random image to a Zarr store so that scale 0 is
    # backed by disk rather than held in memory.
    rng = np.random.default_rng(0)
    pixels = rng.integers(low=0, high=2**16, size=(1, side, side))
    zarr_path = tmp_path / "image.zarr"
    dask.array.from_array(pixels).rechunk(chunk).to_zarr(zarr_path)
    backed = dask.array.from_zarr(zarr_path)

    # Parse into a multiscale image model and wrap it in a SpatialData object.
    parsed = Image2DModel.parse(backed, dims=("c", "y", "x"), scale_factors=factors, chunks=chunk)
    return SpatialData(images={"image": parsed})
36+
37+
38+
def count_chunks(array: Union[xr.DataArray, xr.Dataset, DataTree]) -> int:
    """Return the total number of Dask chunks in ``array``.

    For a ``DataTree`` node, its dataset is inspected. ``chunksizes`` maps each
    dimension to the tuple of chunk lengths along that dimension, so the number
    of chunks per dimension is the length of that tuple and the total chunk
    count is the product over all dimensions.

    The result is coerced to a plain ``int``: a bare ``np.prod`` over a Python
    list yields a NumPy scalar, and a ``float`` ``1.0`` when ``chunksizes`` is
    empty, which would violate the declared return type.
    """
    if isinstance(array, DataTree):
        array = array.ds
    # Multiply the per-axis chunk counts to get the total number of chunks in 2D/3D.
    return int(np.prod([len(chunk_sizes) for chunk_sizes in array.chunksizes.values()]))
44+
45+
46+
@pytest.mark.parametrize(
    ("sdata_with_image",),
    [
        ({"width": 32, "chunk_size": 16, "scale_factors": (2,)},),
        ({"width": 64, "chunk_size": 16, "scale_factors": (2, 2)},),
        ({"width": 128, "chunk_size": 16, "scale_factors": (2, 2, 2)},),
        ({"width": 256, "chunk_size": 16, "scale_factors": (2, 2, 2, 2)},),
    ],
    indirect=["sdata_with_image"],
)
def test_write_image_multiscale_performance(sdata_with_image: SpatialData, tmp_path: Path, mocker):
    """Writing a multiscale image should not read the input more often than necessary.

    Rather than timing the write (noisy unless datasets are large), spy on
    Zarr chunk reads/writes and compare the observed counts against the
    theoretical minimum (see https://github.com/scverse/spatialdata/issues/577).
    """
    write_spy = mocker.spy(zarr.core.Array, "__setitem__")
    read_spy = mocker.spy(zarr.core.Array, "__getitem__")

    name, image = next(iter(sdata_with_image.images.items()))
    images_group = zarr.group(store=tmp_path / "sdata.zarr", path="/images")

    write_image(
        image=image,
        group=images_group,
        name=name,
        format=CurrentRasterFormat(),
    )

    # Expected reads: each chunk of scale 0 is read exactly once.
    # Expected writes: one write per chunk, summed over every scale level.
    if isinstance(image, DataTree):
        expected_reads = count_chunks(image.scale0)
        expected_writes = sum(count_chunks(level) for level in image.children.values())
    else:
        expected_reads = count_chunks(image)
        expected_writes = count_chunks(image)

    assert write_spy.call_count == expected_writes
    assert read_spy.call_count == expected_reads

0 commit comments

Comments
 (0)