Skip to content

Commit fcbf971

Browse files
committed
add decorator skipping nodes without dimensions
1 parent abd3a2f commit fcbf971

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/spatialdata/_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,22 @@ def _error_message_add_element() -> None:
311311
"write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
312312
"ergonomic in a follow up PR."
313313
)
314+
315+
316+
def skip_non_dimension_nodes(func: Callable[[Dataset], Dataset]) -> Callable[[Dataset], Dataset]:
317+
"""Skip nodes in Datatree that do not contain dimensions.
318+
319+
This function implements the workaround of https://github.com/pydata/xarray/issues/9693. In particular,
320+
we need this because of our DataTree representing multiscale image having a root node that does not have
321+
dimensions. Several functions need to be mapped over the datasets in the datatree that depend on having
322+
dimensions, e.g. a transpose.
323+
"""
324+
325+
@functools.wraps(func)
326+
def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
327+
# check if dimensions are present otherwise return verbatim
328+
if len(ds.dims) == 0:
329+
return ds
330+
return func(ds, *args, **kwargs)
331+
332+
return _func

tests/utils/test_element_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import xarray
66
from xarray import DataArray, DataTree
77

8-
from spatialdata._utils import unpad_raster
8+
from spatialdata._utils import skip_non_dimension_nodes, unpad_raster
99
from spatialdata.models import get_model
1010
from spatialdata.transformations import Affine
1111

@@ -64,3 +64,19 @@ def test_unpad_raster(images, labels) -> None:
6464
raise e
6565
else:
6666
raise ValueError(f"Unknown type: {type(raster)}")
67+
68+
69+
def test_skip_nodes(images):
70+
multiscale_img = images["image2d_multiscale"]
71+
72+
@skip_non_dimension_nodes
73+
def transpose(ds, *args, **kwargs):
74+
return ds.transpose(*args, **kwargs)
75+
76+
for scale in list(multiscale_img.keys()):
77+
assert multiscale_img[scale]["image"].dims == ("c", "y", "x")
78+
79+
# applying this function without skipping the root node would fail as the root node does not have dimensions.
80+
result = images["image2d_multiscale"].map_over_datasets(transpose, "y", "x", "c")
81+
for scale in list(result.keys()):
82+
assert result[scale]["image"].dims == ("y", "x", "c")

0 commit comments

Comments
 (0)