Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
585 changes: 103 additions & 482 deletions docs/examples/convert_to_raster.ipynb

Large diffs are not rendered by default.

93 changes: 73 additions & 20 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from xarray.core.variable import as_variable

from rioxarray._spatial_utils import (
DEFAULT_GRID_MAP,
FILL_VALUE_NAMES,
UNWANTED_RIO_ATTRS,
UNWANTED_TAGS,
_generate_spatial_coords,
)
from rioxarray.exceptions import RioXarrayError
Expand Down Expand Up @@ -654,17 +656,56 @@ def build_subdataset_filter(
)


def _parse_and_sanitize_tags(
tags: dict, other_unwanted_attrs: Optional[tuple] = None
) -> dict:
attrs = _parse_tags(tags)

if other_unwanted_attrs is None:
other_unwanted_attrs = FILL_VALUE_NAMES + UNWANTED_RIO_ATTRS
else:
other_unwanted_attrs += FILL_VALUE_NAMES
other_unwanted_attrs += UNWANTED_RIO_ATTRS

# that should be added by GDAL/rasterio
for unwanted_attr in other_unwanted_attrs:
if unwanted_attr.endswith("*"):
for attr in attrs:
if attr.startswith(unwanted_attr[:-1]):
attrs.pop(unwanted_attr, None)
else:
attrs.pop(unwanted_attr, None)

return attrs


def _sanitize_netcdf_dims_in_attrs(riods, attrs, coords):
for coord in coords:
if f"NETCDF_DIM_{coord}" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}")
break
if f"NETCDF_DIM_{coord}_VALUES" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}_VALUES")
attrs.pop(f"NETCDF_DIM_{coord}_DEF", None)
attrs.pop("NETCDF_DIM_EXTRA", None)
break
else:
coord_name = "band"
coords[coord_name] = numpy.asarray(riods.indexes)

return coord_name


def _get_rasterio_attrs(riods: RasterioReader):
"""
Get rasterio specific attributes
"""
# pylint: disable=too-many-branches
# Add rasterio attributes
attrs = _parse_tags({**riods.tags(), **riods.tags(1)})
# remove attributes with informaiton
# that should be added by GDAL/rasterio
for unwanted_attr in FILL_VALUE_NAMES + UNWANTED_RIO_ATTRS:
attrs.pop(unwanted_attr, None)
# Add rasterio attributes (add the first band attributes into the raster's attribute)
attrs = _parse_and_sanitize_tags({**riods.tags(), **riods.tags(1)})

if riods.nodata is not None:
# The nodata values for the raster bands
attrs["_FillValue"] = riods.nodata
Expand Down Expand Up @@ -997,6 +1038,11 @@ def _single_band_open(*args, bidx=0, **kwargs):
)


def _remove_from_band_tags(band_tags: list[dict], to_be_removed: str):
for tags in band_tags:
tags.pop(to_be_removed, None)


def open_rasterio(
filename: Union[
str,
Expand Down Expand Up @@ -1196,20 +1242,18 @@ def open_rasterio(
attrs = _get_rasterio_attrs(riods=riods)
coords = _load_netcdf_1d_coords(attrs)
_parse_driver_tags(riods=riods, attrs=attrs, coords=coords)
for coord in coords:
if f"NETCDF_DIM_{coord}" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}")
break
if f"NETCDF_DIM_{coord}_VALUES" in attrs:
coord_name = coord
attrs.pop(f"NETCDF_DIM_{coord}_VALUES")
attrs.pop(f"NETCDF_DIM_{coord}_DEF", None)
attrs.pop("NETCDF_DIM_EXTRA", None)
break
else:
coord_name = "band"
coords[coord_name] = numpy.asarray(riods.indexes)
coord_name = _sanitize_netcdf_dims_in_attrs(riods=riods, attrs=attrs, coords=coords)

# Add band tags in band_tags as a list: one dict per band
band_tags = []
for i in range(riods.count):
tags = _parse_and_sanitize_tags(
riods.tags(i + 1),
other_unwanted_attrs=UNWANTED_TAGS
+ (DEFAULT_GRID_MAP, "AREA_OR_POINT", "NETCDF_*"),
)
if len(tags) > 0:
band_tags.append(tags)

# Handle GCPs and RPCs
has_gcps = riods.gcps[0]
Expand All @@ -1231,11 +1275,14 @@ def open_rasterio(
encoding: dict[Hashable, Any] = {}
if mask_and_scale and "_Unsigned" in attrs:
unsigned = variables.pop_to(attrs, encoding, "_Unsigned") == "true"
_remove_from_band_tags(band_tags, "_Unsigned")

if masked:
encoding["dtype"] = str(_rasterio_to_numpy_dtype(riods.dtypes))

da_name = attrs.pop("NETCDF_VARNAME", default_name)
_remove_from_band_tags(band_tags, "NETCDF_VARNAME")

data: Any = indexing.LazilyOuterIndexedArray(
RasterioArrayWrapper(
manager=manager,
Expand Down Expand Up @@ -1267,6 +1314,7 @@ def open_rasterio(
# make sure the _FillValue is correct dtype
if "_FillValue" in result.attrs:
result.attrs["_FillValue"] = result.dtype.type(result.attrs["_FillValue"])
_remove_from_band_tags(band_tags, "_FillValue")

# handle encoding
_handle_encoding(
Expand All @@ -1290,6 +1338,10 @@ def open_rasterio(
if has_rpcs:
result.rio.write_rpcs(riods.rpcs, inplace=True)

# Add band tags
if band_tags and band_tags[0]:
result.rio.write_band_tags(band_tags, inplace=True)

if chunks is not None:
result = _prepare_dask(
result=result,
Expand Down Expand Up @@ -1322,6 +1374,7 @@ def open_rasterio(
for attr, value in result.attrs.items()
if not attr.startswith(f"{result.name}#")
}

# Make the file closeable
result.set_close(manager.close)
result.rio._manager = manager
Expand Down
9 changes: 9 additions & 0 deletions rioxarray/_spatial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
FILL_VALUE_NAMES = ("_FillValue", "missing_value", "fill_value", "nodata")
UNWANTED_RIO_ATTRS = ("nodatavals", "is_tiled", "res")
DEFAULT_GRID_MAP = "spatial_ref"
UNWANTED_TAGS = (
"crs",
"transform",
"scales",
"scale_factor",
"add_offset",
"offsets",
"grid_mapping",
)

# DTYPE TO NODATA MAP
# Based on: https://github.com/OSGeo/gdal/blob/v3.12.1/swig/python/gdal-utils/osgeo_utils/gdal_calc.py#L49-L66
Expand Down
52 changes: 50 additions & 2 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ def to_raster(
*,
driver: Optional[str] = None,
dtype: Optional[Union[str, numpy.dtype]] = None,
tags: Optional[dict[str, str]] = None,
tags: Optional[dict[str, str | list]] = None,
windowed: bool = False,
recalc_transform: bool = True,
lock: Optional[bool] = None,
Expand Down Expand Up @@ -1058,6 +1058,15 @@ def to_raster(
self.encoded_nodata if self.encoded_nodata is not None else self.nodata
)

# Add band tags
band_tags = self.get_band_tags()

if band_tags:
if tags is None:
tags = {"band_tags": band_tags}
else:
tags["band_tags"] = band_tags

return RasterioWriter(raster_path=raster_path).to_raster(
xarray_dataarray=self._obj,
tags=tags,
Expand Down Expand Up @@ -1091,11 +1100,50 @@ def to_rasterio_dataset(self) -> Generator[DatasetReader, None, None]:
Example
-------

>>> with xds.to_rasterio_dataset() as rio_ds:
>>> with xda.rio.to_rasterio_dataset() as rio_ds:
>>> rio_ds.count

"""
with MemoryFile() as memfile:
self.to_raster(memfile.name)
with memfile.open() as src_ds:
yield src_ds

def write_band_tags(
self, band_tags: list[dict], inplace: bool = False
) -> xarray.DataArray:
"""
Write band tags to the :obj:`xarray.DataArray`'s attributes, ensuring one tag per band.

The tags are stored in the array's attributes under the key :code:`"band_tags"`, ensuring they'll be written on disk with :func:`to_raster`.

Parameters
----------
band_tags: list[dict]
A list of dictionnaries, one per band, containing the bands' metadata.

Returns
-------
:obj:`xarray.DataArray`:
Modified DataArray with band tags

Raises
------
AssertionError:
If the length of `band_tags` does not match the number of bands.

Example
-------

>>> band_tags = [
>>> {"year": "yesterday", "where": "here"},
>>> {"year": "now", "where": "here"}
>>> ], # Raster has two bands
>>> xda.rio.write_band_tags(band_tags, inplace=True)
"""
assert len(band_tags) == self.count, "You should give one band tag per band."

data_obj: xarray.DataArray = self._get_obj(inplace=inplace) # type: ignore
data_obj.rio._band_tags = band_tags
data_obj.encoding["band_tags"] = band_tags
return data_obj
Loading
Loading