diff --git a/docs/history.rst b/docs/history.rst index 1ec98bff..5bf1af8a 100644 --- a/docs/history.rst +++ b/docs/history.rst @@ -1,6 +1,10 @@ History ======= +Unreleased +---------- +- ENH: Add write support for Zarr spatial and proj conventions + 0.22.0 ------ - ENH: Add read support for Zarr spatial and proj conventions (#900) diff --git a/rioxarray/_convention/zarr.py b/rioxarray/_convention/zarr.py index 0784ad54..bff5e75c 100644 --- a/rioxarray/_convention/zarr.py +++ b/rioxarray/_convention/zarr.py @@ -8,6 +8,7 @@ from typing import Optional, Union import rasterio.crs +import rasterio.transform import xarray from affine import Affine @@ -183,6 +184,42 @@ def _parse_transform_from_attrs( return None +# ============================================================================ +# Writing utilities +# ============================================================================ + +_CONVENTION_DICTS = {"proj:": PROJ_CONVENTION, "spatial:": SPATIAL_CONVENTION} + + +def add_convention_declaration(attrs: dict, convention_name: str) -> dict: + """ + Add a convention to the zarr_conventions list in attrs, idempotent. + + Parameters + ---------- + attrs : dict + Attributes dictionary to modify in place + convention_name : str + Name of the convention to declare (e.g., "proj:" or "spatial:") + + Returns + ------- + dict + The modified attrs dict + """ + if has_convention_declared(attrs, convention_name): + return attrs + zarr_conventions = list(attrs.get("zarr_conventions") or []) + zarr_conventions.append(_CONVENTION_DICTS[convention_name]) + attrs["zarr_conventions"] = zarr_conventions + return attrs + + +def format_spatial_transform(affine: Affine) -> list: + """Convert Affine to spatial:transform array [a, b, c, d, e, f].""" + return [affine.a, affine.b, affine.c, affine.d, affine.e, affine.f] + + # ============================================================================ # ZarrConvention class implementing ConventionProtocol # ============================================================================ @@ -270,9 +307,7 @@ def write_crs( **kwargs, # pylint: disable=unused-argument ) -> Union[xarray.Dataset, xarray.DataArray]: """ - Write CRS using Zarr conventions. - - Note: Writing support will be implemented in a future PR. + Write CRS using Zarr proj: convention. Parameters ---------- @@ -281,22 +316,17 @@ def write_crs( crs : rasterio.crs.CRS CRS to write **kwargs - Additional convention-specific parameters + Additional convention-specific parameters (e.g., grid_mapping_name for CF; + silently ignored here) Returns ------- xarray.Dataset or xarray.DataArray Object with CRS written - - Raises - ------ - NotImplementedError - Zarr write support is not yet implemented """ - raise NotImplementedError( - "Zarr CRS writing is not yet implemented. " - "Use Convention.CF for writing or wait for a future release." - ) + add_convention_declaration(obj.attrs, "proj:") + obj.attrs["proj:wkt2"] = crs.to_wkt() + return obj @classmethod def write_transform( @@ -307,9 +337,7 @@ def write_transform( **kwargs, # pylint: disable=unused-argument ) -> Union[xarray.Dataset, xarray.DataArray]: """ - Write transform using Zarr conventions. - - Note: Writing support will be implemented in a future PR. + Write transform using Zarr spatial: convention. Parameters ---------- @@ -318,19 +346,25 @@ def write_transform( transform : affine.Affine Transform to write **kwargs - Additional convention-specific parameters + Additional convention-specific parameters (e.g., grid_mapping_name for CF; + silently ignored here) Returns ------- xarray.Dataset or xarray.DataArray Object with transform written - - Raises - ------ - NotImplementedError - Zarr write support is not yet implemented """ - raise NotImplementedError( - "Zarr transform writing is not yet implemented. " - "Use Convention.CF for writing or wait for a future release." + add_convention_declaration(obj.attrs, "spatial:") + obj.attrs["spatial:transform"] = format_spatial_transform(transform) + y_dim = obj.rio.y_dim + x_dim = obj.rio.x_dim + height = obj.sizes[y_dim] + width = obj.sizes[x_dim] + obj.attrs["spatial:dimensions"] = [y_dim, x_dim] + obj.attrs["spatial:shape"] = [height, width] + left, bottom, right, top = rasterio.transform.array_bounds( + height, width, transform ) + obj.attrs["spatial:bbox"] = [left, bottom, right, top] + obj.attrs["spatial:registration"] = "pixel" + return obj diff --git a/test/integration/test_integration_zarr_conventions.py b/test/integration/test_integration_zarr_conventions.py index 926fa602..292d34b0 100644 --- a/test/integration/test_integration_zarr_conventions.py +++ b/test/integration/test_integration_zarr_conventions.py @@ -209,3 +209,54 @@ def test_read_proj_projjson(): crs = data.rio.crs assert crs is not None assert crs == CRS.from_epsg(4326) + + +# ============================================================================ +# Write tests +# ============================================================================ + + +def test_write_crs__zarr_convention(): + """Test writing CRS via Convention.ZARR produces correct proj: attributes.""" + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + result = data.rio.write_crs("EPSG:4326", convention=Convention.ZARR) + assert zarr.has_convention_declared(result.attrs, "proj:") is True + assert "proj:wkt2" in result.attrs + assert CRS.from_wkt(result.attrs["proj:wkt2"]) == CRS.from_epsg(4326) + + +def test_write_transform__zarr_convention(): + """Test writing transform via Convention.ZARR produces correct spatial: attributes.""" + transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + result = data.rio.write_transform(transform, convention=Convention.ZARR) + assert zarr.has_convention_declared(result.attrs, "spatial:") is True + assert result.attrs["spatial:transform"] == [1.0, 0.0, 0.0, 0.0, -1.0, 10.0] + assert result.attrs["spatial:dimensions"] == ["y", "x"] + assert result.attrs["spatial:shape"] == [10, 20] + assert "spatial:bbox" in result.attrs + assert result.attrs["spatial:registration"] == "pixel" + + +def test_write_crs__zarr_roundtrip(): + """Test that a CRS written with ZARR convention can be read back.""" + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + written = data.rio.write_crs("EPSG:4326", convention=Convention.ZARR) + assert written.rio.crs == CRS.from_epsg(4326) + + +def test_write_transform__zarr_roundtrip(): + """Test that a transform written with ZARR convention can be read back.""" + transform = Affine(1.0, 0.0, 0.0, 0.0, -1.0, 10.0) + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + written = data.rio.write_transform(transform, convention=Convention.ZARR) + assert written.rio._cached_transform() == transform + + +def test_write_crs__zarr_via_set_options(): + """Test writing CRS with Convention.ZARR set via set_options().""" + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + with set_options(convention=Convention.ZARR): + result = data.rio.write_crs("EPSG:4326") + assert zarr.has_convention_declared(result.attrs, "proj:") is True + assert "proj:wkt2" in result.attrs diff --git a/test/unit/test_convention_zarr.py b/test/unit/test_convention_zarr.py index 83f4f5b7..ba938c4b 100644 --- a/test/unit/test_convention_zarr.py +++ b/test/unit/test_convention_zarr.py @@ -4,6 +4,7 @@ from affine import Affine from rasterio.crs import CRS +import rioxarray # noqa: F401 from rioxarray._convention import zarr from rioxarray._convention.zarr import ZarrConvention @@ -163,3 +164,86 @@ def test_read_spatial_dimensions__no_convention_declared(): dims = ZarrConvention.read_spatial_dimensions(data) assert dims is None + + +# ============================================================================ +# Formatting utilities +# ============================================================================ + + +def test_format_spatial_transform(): + """Test converting Affine to [a, b, c, d, e, f] list.""" + affine = Affine(1.0, 0.0, 100.0, 0.0, -1.0, 200.0) + assert zarr.format_spatial_transform(affine) == [1.0, 0.0, 100.0, 0.0, -1.0, 200.0] + + +# ============================================================================ +# Convention declaration +# ============================================================================ + + +def test_add_convention_declaration(): + """Test adding a convention declaration to empty attrs.""" + attrs = {} + zarr.add_convention_declaration(attrs, "proj:") + assert zarr.has_convention_declared(attrs, "proj:") is True + + +def test_add_convention_declaration__idempotent(): + """Test that duplicate declarations are not added.""" + attrs = {} + zarr.add_convention_declaration(attrs, "proj:") + zarr.add_convention_declaration(attrs, "proj:") + assert len(attrs["zarr_conventions"]) == 1 + + +# ============================================================================ +# ZarrConvention.write_crs +# ============================================================================ + + +def test_write_crs(): + """Test writing CRS writes proj:wkt2 and declares convention.""" + data = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"]) + crs = CRS.from_epsg(4326) + result = ZarrConvention.write_crs(data, crs=crs) + assert zarr.has_convention_declared(result.attrs, "proj:") is True + assert "proj:wkt2" in result.attrs + assert CRS.from_wkt(result.attrs["proj:wkt2"]) == crs + + +def test_write_crs__ignores_grid_mapping_name(): + """Test that grid_mapping_name kwarg (CF-specific) is silently ignored.""" + data = xr.DataArray(np.random.rand(10, 10), dims=["y", "x"]) + result = ZarrConvention.write_crs( + data, crs=CRS.from_epsg(4326), grid_mapping_name="spatial_ref" + ) + assert "proj:wkt2" in result.attrs + + +# ============================================================================ +# ZarrConvention.write_transform +# ============================================================================ + + +def test_write_transform(): + """Test writing transform writes all spatial: attributes.""" + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + transform = Affine(1.0, 0.0, 100.0, 0.0, -1.0, 200.0) + result = ZarrConvention.write_transform(data, transform=transform) + assert zarr.has_convention_declared(result.attrs, "spatial:") is True + assert result.attrs["spatial:transform"] == [1.0, 0.0, 100.0, 0.0, -1.0, 200.0] + assert result.attrs["spatial:dimensions"] == ["y", "x"] + assert result.attrs["spatial:shape"] == [10, 20] + assert "spatial:bbox" in result.attrs + assert result.attrs["spatial:registration"] == "pixel" + + +def test_write_transform__ignores_grid_mapping_name(): + """Test that grid_mapping_name kwarg (CF-specific) is silently ignored.""" + data = xr.DataArray(np.random.rand(10, 20), dims=["y", "x"]) + transform = Affine(1.0, 0.0, 100.0, 0.0, -1.0, 200.0) + result = ZarrConvention.write_transform( + data, transform=transform, grid_mapping_name="spatial_ref" + ) + assert "spatial:transform" in result.attrs