Skip to content

Commit abd3a2f

Browse files
committed
adjust to xarray datatree
1 parent 27bb4a7 commit abd3a2f

33 files changed

+59
-85
lines changed

docs/tutorials/notebooks

Submodule notebooks updated 48 files

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
"dask>=2024.4.1",
2929
"fsspec<=2023.6",
3030
"geopandas>=0.14",
31-
"multiscale_spatial_image>=1.0.0",
31+
#"multiscale_spatial_image>=1.0.0", Uncomment when new release using xr.DataTree is out
3232
"networkx",
3333
"numba",
3434
"numpy",
@@ -43,8 +43,7 @@ dependencies = [
4343
"scikit-image",
4444
"scipy",
4545
"typing_extensions>=4.8.0",
46-
"xarray",
47-
"xarray-datatree",
46+
"xarray>=2024.10.0",
4847
"xarray-schema",
4948
"xarray-spatial>=0.3.5",
5049
"zarr",

src/spatialdata/_core/_deepcopy.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from dask.array.core import Array as DaskArray
88
from dask.array.core import from_array
99
from dask.dataframe import DataFrame as DaskDataFrame
10-
from datatree import DataTree
1110
from geopandas import GeoDataFrame
12-
from xarray import DataArray
11+
from xarray import DataArray, DataTree
1312

1413
from spatialdata._core.spatialdata import SpatialData
1514
from spatialdata.models._utils import SpatialElement

src/spatialdata/_core/centroids.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
import pandas as pd
88
import xarray as xr
99
from dask.dataframe import DataFrame as DaskDataFrame
10-
from datatree import DataTree
1110
from geopandas import GeoDataFrame
1211
from shapely import MultiPolygon, Point, Polygon
13-
from xarray import DataArray
12+
from xarray import DataArray, DataTree
1413

1514
from spatialdata._core.operations.transform import transform
1615
from spatialdata.models import get_axes_names

src/spatialdata/_core/data_extent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import numpy as np
99
import pandas as pd
1010
from dask.dataframe import DataFrame as DaskDataFrame
11-
from datatree import DataTree
1211
from geopandas import GeoDataFrame
1312
from shapely import MultiPolygon, Point, Polygon
14-
from xarray import DataArray
13+
from xarray import DataArray, DataTree
1514

1615
from spatialdata._core.operations.transform import transform
1716
from spatialdata._core.spatialdata import SpatialData

src/spatialdata/_core/operations/_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
from typing import TYPE_CHECKING
44

5-
from datatree import DataTree
6-
from xarray import DataArray
5+
from xarray import DataArray, DataTree
76

87
from spatialdata.models import SpatialElement
98

src/spatialdata/_core/operations/aggregate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
import numpy as np
1111
import pandas as pd
1212
from dask.dataframe import DataFrame as DaskDataFrame
13-
from datatree import DataTree
1413
from geopandas import GeoDataFrame
1514
from scipy import sparse
1615
from shapely import Point
17-
from xarray import DataArray
16+
from xarray import DataArray, DataTree
1817
from xrspatial import zonal_stats
1918

2019
from spatialdata._core.operations._utils import _parse_element

src/spatialdata/_core/operations/map.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
import dask.array as da
88
from dask.array.overlap import coerce_depth
9-
from datatree import DataTree
10-
from xarray import DataArray
9+
from xarray import DataArray, DataTree
1110

1211
from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
1312
from spatialdata.transformations import get_transformation

src/spatialdata/_core/operations/rasterize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import numpy as np
66
from dask.array import Array as DaskArray
77
from dask.dataframe import DataFrame as DaskDataFrame
8-
from datatree import DataTree
98
from geopandas import GeoDataFrame
109
from shapely import Point
11-
from xarray import DataArray
10+
from xarray import DataArray, DataTree
1211

1312
from spatialdata._core.operations._utils import _parse_element
1413
from spatialdata._core.operations.transform import transform

src/spatialdata/_core/operations/transform.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
import numpy as np
1111
from dask.array.core import Array as DaskArray
1212
from dask.dataframe import DataFrame as DaskDataFrame
13-
from datatree import DataTree
1413
from geopandas import GeoDataFrame
1514
from shapely import Point
16-
from xarray import DataArray
15+
from xarray import DataArray, Dataset, DataTree
1716

1817
from spatialdata._core.spatialdata import SpatialData
1918
from spatialdata._types import ArrayLike
@@ -393,8 +392,11 @@ def _(
393392
raster_translation = raster_translation_single_scale
394393
# we set a dummy empty dict for the transformation that will be replaced with the correct transformation for
395394
# each scale later in this function, when calling set_transformation()
396-
transformed_dict[k] = DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})
395+
transformed_dict[k] = Dataset(
396+
{"image": DataArray(transformed_dask, dims=xdata.dims, name=xdata.name, attrs={TRANSFORM_KEY: {}})}
397+
)
397398
if channel_names is not None:
399+
# This expression returns a dataset now.
398400
transformed_dict[k] = transformed_dict[k].assign_coords(c=channel_names)
399401

400402
# mypy thinks that schema could be ShapesModel, PointsModel, ...

0 commit comments

Comments
 (0)