Merged
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -169,6 +169,7 @@ open_geotiff('dem.tif', dtype='float32') # half memory
open_geotiff('dem.tif', dtype='float32', chunks=512) # Dask + half memory
to_geotiff(data, 'out.tif', compression_level=1) # fast scratch write
to_geotiff(data, 'out.tif', compression_level=22) # max compression
to_geotiff(dask_da, 'out.tif') # stream Dask to single TIFF
to_geotiff(dask_da, 'mosaic.vrt') # stream Dask to VRT

# Accessor methods
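The peak-memory claim behind these README one-liners can be checked with back-of-envelope arithmetic. A minimal sketch of the formula from this PR's docstring (`tile_size * width * bytes_per_sample`); the concrete raster and tile sizes here are illustrative, not taken from the PR:

```python
# Back-of-envelope peak memory for eager vs streaming GeoTIFF writes.
# The streaming formula (tile_size * width * bytes_per_sample) comes from
# the to_geotiff docstring in this PR; the sizes below are made up.
height, width = 20_000, 20_000      # a raster too large for comfort in RAM
bytes_per_sample = 4                # float32
tile_size = 512                     # a typical GeoTIFF tile edge

full_array = height * width * bytes_per_sample
one_tile_row = tile_size * width * bytes_per_sample

print(f"eager write:     {full_array / 2**30:.2f} GiB")   # 1.49 GiB
print(f"streaming write: {one_tile_row / 2**20:.1f} MiB") # 39.1 MiB
```

At these sizes the streaming path holds roughly 1/39th of the eager footprint; the ratio is just `height / tile_size`.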
257 changes: 257 additions & 0 deletions examples/user_guide/47_Streaming_GeoTIFF_Write.ipynb
@@ -0,0 +1,257 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Streaming GeoTIFF write from dask arrays\n",
"\n",
"When you call `to_geotiff()` on a dask-backed DataArray, the data is written one tile-row at a time. Only one tile-row lives in memory at once, so you can write rasters larger than RAM without switching to VRT output.\n",
"\n",
"This notebook shows the three write modes for dask data:\n",
"1. **Streaming to a single TIFF** (automatic when the input is dask-backed)\n",
"2. **Streaming to a VRT** (one file per chunk, stitched by an XML index)\n",
"3. **Eager write** (materialise first, then write; needed for COG with overviews)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import tempfile\n",
"import os\n",
"\n",
"import numpy as np\n",
"import xarray as xr\n",
"import dask.array as da\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from xrspatial.geotiff import open_geotiff, to_geotiff"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build a dask-backed raster\n",
"\n",
"A 2000x2000 terrain surface chunked into 500x500 blocks. Four chunks along each axis, sixteen chunks total."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng(1084)\n",
"H, W = 2000, 2000\n",
"\n",
"yy, xx = np.meshgrid(\n",
" np.linspace(0, 6 * np.pi, H),\n",
" np.linspace(0, 6 * np.pi, W),\n",
" indexing='ij',\n",
")\n",
"terrain = (500 + 200 * np.sin(yy) * np.cos(xx * 0.7)\n",
" + 30 * rng.standard_normal((H, W))).astype(np.float32)\n",
"\n",
"y = np.linspace(45.0, 44.0, H)\n",
"x = np.linspace(-122.0, -121.0, W)\n",
"\n",
"raster = xr.DataArray(\n",
" terrain, dims=['y', 'x'],\n",
" coords={'y': y, 'x': x},\n",
" attrs={'crs': 4326, 'nodata': -9999.0},\n",
")\n",
"\n",
"dask_raster = raster.chunk({'y': 500, 'x': 500})\n",
"print(f'Shape: {dask_raster.shape}')\n",
"print(f'Chunks: {dask_raster.chunks}')\n",
"print(f'dtype: {dask_raster.dtype}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(6, 6))\n",
"raster.plot.imshow(ax=ax, cmap='terrain', add_colorbar=True)\n",
"ax.set_title('Synthetic terrain (2000x2000)')\n",
"ax.set_axis_off()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Streaming write to a single TIFF\n",
"\n",
"Pass the dask-backed DataArray to `to_geotiff()` the same way you would a numpy array. The streaming path kicks in automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tmpdir = tempfile.mkdtemp(prefix='xrs_stream_nb_')\n",
"\n",
"tif_path = os.path.join(tmpdir, 'streamed.tif')\n",
"to_geotiff(dask_raster, tif_path)\n",
"\n",
"print(f'File size: {os.path.getsize(tif_path):,} bytes')\n",
"\n",
"# Read back and verify\n",
"loaded = open_geotiff(tif_path)\n",
"print(f'Shape: {loaded.shape}')\n",
"print(f'CRS: {loaded.attrs.get(\"crs\")}')\n",
"print(f'Match: {np.allclose(loaded.values, raster.values)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's it. Same API, same output, but peak memory was roughly `tile_size * width * 4 bytes` instead of the full 2000x2000 array."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Streaming write to a VRT\n",
"\n",
"If you want one tile per dask chunk (useful when chunks are large or you plan to read subregions later), write to a `.vrt` path instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vrt_path = os.path.join(tmpdir, 'tiled.vrt')\n",
"to_geotiff(dask_raster, vrt_path)\n",
"\n",
"tiles_dir = os.path.join(tmpdir, 'tiled_tiles')\n",
"tile_files = sorted(os.listdir(tiles_dir))\n",
"print(f'VRT file: {os.path.getsize(vrt_path):,} bytes')\n",
"print(f'Tile count: {len(tile_files)}')\n",
"print(f'Tiles: {tile_files}')\n",
"\n",
"mosaic = open_geotiff(vrt_path)\n",
"print(f'\\nMosaic shape: {mosaic.shape}')\n",
"print(f'Match: {np.allclose(mosaic.values, raster.values)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Four chunks along each axis produce 16 tile files, stitched by a lightweight XML index."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Compression and layout options\n",
"\n",
"All `to_geotiff` options work with the streaming path. Try different codecs and see the file size difference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"codecs = ['none', 'deflate', 'zstd', 'lzw']\n",
"sizes = {}\n",
"\n",
"for codec in codecs:\n",
" p = os.path.join(tmpdir, f'stream_{codec}.tif')\n",
" to_geotiff(dask_raster, p, compression=codec)\n",
" sizes[codec] = os.path.getsize(p)\n",
"\n",
"for codec, sz in sizes.items():\n",
" ratio = sz / sizes['none']\n",
" print(f'{codec:>8s}: {sz:>12,} bytes ({ratio:.2%} of uncompressed)')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. When streaming doesn't apply\n",
"\n",
"COG output with `cog=True` needs overviews, which are built from the full array. In that case `to_geotiff` falls through to the eager path and calls `.compute()` as before."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cog_path = os.path.join(tmpdir, 'eager_cog.tif')\n",
"to_geotiff(dask_raster, cog_path, cog=True)\n",
"\n",
"print(f'COG size: {os.path.getsize(cog_path):,} bytes')\n",
"cog = open_geotiff(cog_path)\n",
"print(f'Match: {np.allclose(cog.values, raster.values)}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the full array doesn't fit in memory, use VRT output instead of COG."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"shutil.rmtree(tmpdir, ignore_errors=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Summary\n",
"\n",
"| Write mode | Path | Peak memory | When to use |\n",
"|:-----------|:-----|:------------|:------------|\n",
"| Streaming TIFF | `out.tif` | ~1 tile-row | Default for dask input |\n",
"| Streaming VRT | `out.vrt` | ~1 chunk | Need per-chunk files |\n",
"| Eager (COG) | `out.tif`, `cog=True` | Full array | Need overviews |"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
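Stripped of the GeoTIFF machinery, the tile-row streaming idea the notebook demonstrates reduces to a simple loop: only one horizontal band of the raster is materialised per iteration. A numpy-only sketch (toy sizes; `band.copy()` stands in for compress-and-write, and none of these names come from the PR):

```python
import numpy as np

H, W, tile = 8, 8, 4   # toy sizes; the real writer uses e.g. tile_size=512
data = np.arange(H * W, dtype=np.float32).reshape(H, W)

bands = []
for y0 in range(0, H, tile):
    band = data[y0:y0 + tile]   # only this slice is "live" at once
    bands.append(band.copy())   # stands in for encoding + writing to disk

# Reassembling the bands recovers the original raster exactly.
roundtrip = np.concatenate(bands, axis=0)
assert np.array_equal(roundtrip, data)
```

With a dask-backed source, each `band` would be computed lazily per iteration instead of sliced from an in-memory array, which is what keeps the peak footprint at one tile-row.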
80 changes: 64 additions & 16 deletions xrspatial/geotiff/__init__.py
@@ -387,6 +387,12 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
gpu: bool | None = None) -> None:
"""Write data as a GeoTIFF or Cloud Optimized GeoTIFF.

Dask-backed DataArrays are written in streaming mode: one tile-row
at a time, without materialising the full array into RAM. Peak
memory is roughly ``tile_size * width * bytes_per_sample``. COG
output (``cog=True``) still materialises because overviews need the
full array.

Automatically dispatches to GPU compression when:
- ``gpu=True`` is passed, or
- The input data is CuPy-backed (auto-detected)
@@ -483,25 +489,14 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
wkt_fallback = crs

if isinstance(data, xr.DataArray):
# Handle CuPy-backed DataArrays: convert to numpy for CPU write
raw = data.data
if hasattr(raw, 'get'):
arr = raw.get() # CuPy -> numpy
elif hasattr(raw, 'compute'):
arr = raw.compute() # Dask -> numpy
if hasattr(arr, 'get'):
arr = arr.get() # Dask+CuPy -> numpy
else:
arr = np.asarray(raw)
# Handle band-first dimension order (band, y, x) -> (y, x, band)
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
arr = np.moveaxis(arr, 0, -1)

# Extract metadata from DataArray attrs (no materialisation needed)
if geo_transform is None:
geo_transform = _coords_to_transform(data)
if epsg is None and crs is None:
crs_attr = data.attrs.get('crs')
if isinstance(crs_attr, str):
# WKT string from reproject() or other source
epsg = _wkt_to_epsg(crs_attr)
if epsg is None and wkt_fallback is None:
wkt_fallback = crs_attr
@@ -517,22 +512,75 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
nodata = data.attrs.get('nodata')
if data.attrs.get('raster_type') == 'point':
raster_type = RASTER_PIXEL_IS_POINT
# GDAL metadata from attrs (prefer raw XML, fall back to dict)
gdal_meta_xml = data.attrs.get('gdal_metadata_xml')
if gdal_meta_xml is None:
gdal_meta_dict = data.attrs.get('gdal_metadata')
if isinstance(gdal_meta_dict, dict):
from ._geotags import _build_gdal_metadata_xml
gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict)
# Extra tags for pass-through
extra_tags_list = data.attrs.get('extra_tags')
# Resolution / DPI from attrs
x_res = data.attrs.get('x_resolution')
y_res = data.attrs.get('y_resolution')
unit_str = data.attrs.get('resolution_unit')
if unit_str is not None:
_unit_ids = {'none': 1, 'inch': 2, 'centimeter': 3}
res_unit = _unit_ids.get(str(unit_str), None)

# Dask-backed: stream tiles to avoid materialising the full array.
# COG requires overviews from the full array, so it falls through
# to the eager path.
if hasattr(raw, 'dask') and not cog:
dask_arr = raw
# Handle band-first dimension order (band, y, x) -> (y, x, band)
if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
import dask.array as da
dask_arr = da.moveaxis(raw, 0, -1)
if dask_arr.ndim not in (2, 3):
raise ValueError(
f"Expected 2D or 3D array, got {dask_arr.ndim}D")
# Validate compression_level
if compression_level is not None:
level_range = _LEVEL_RANGES.get(compression.lower())
if level_range is not None:
lo, hi = level_range
if not (lo <= compression_level <= hi):
raise ValueError(
f"compression_level={compression_level} out of "
f"range for {compression} (valid: {lo}-{hi})")
from ._writer import write_streaming
write_streaming(
dask_arr, path,
geo_transform=geo_transform,
crs_epsg=epsg,
crs_wkt=wkt_fallback if epsg is None else None,
nodata=nodata,
compression=compression,
compression_level=compression_level,
tiled=tiled,
tile_size=tile_size,
predictor=predictor,
raster_type=raster_type,
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
gdal_metadata_xml=gdal_meta_xml,
extra_tags=extra_tags_list,
bigtiff=bigtiff,
)
return

# Eager compute (numpy, CuPy, or dask+COG)
if hasattr(raw, 'get'):
arr = raw.get() # CuPy -> numpy
elif hasattr(raw, 'compute'):
arr = raw.compute() # Dask -> numpy
if hasattr(arr, 'get'):
arr = arr.get() # Dask+CuPy -> numpy
else:
arr = np.asarray(raw)
# Handle band-first dimension order (band, y, x) -> (y, x, band)
if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
arr = np.moveaxis(arr, 0, -1)
else:
if hasattr(data, 'get'):
arr = data.get() # CuPy -> numpy
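The dispatch added in this diff keys off duck typing rather than imports: dask arrays carry a `.dask` graph attribute, CuPy arrays expose `.get()`. A stand-alone sketch of that branch order with stand-in classes (`FakeCupy`, `FakeDask`, and `pick_path` are illustrative names, not from the PR):

```python
# Stand-ins for the real backends, so the branch logic runs without
# dask or CuPy installed.
class FakeCupy:
    def get(self):
        return "numpy-from-cupy"

class FakeDask:
    dask = object()          # dask arrays expose a .dask task graph
    def compute(self):
        return "numpy-from-dask"

def pick_path(raw, cog=False):
    # Mirrors the branch order in the diff: streaming first (dask and
    # not COG), then the eager conversions.
    if hasattr(raw, 'dask') and not cog:
        return 'streaming'
    if hasattr(raw, 'get'):
        return 'eager-cupy'
    if hasattr(raw, 'compute'):
        return 'eager-dask'   # dask + cog=True falls through to here
    return 'eager-numpy'

print(pick_path(FakeDask()))            # streaming
print(pick_path(FakeDask(), cog=True))  # eager-dask
print(pick_path(FakeCupy()))            # eager-cupy
```

The `cog` guard is what makes `cog=True` materialise the full array: the streaming branch is skipped and the dask input drops into the `.compute()` path, matching the docstring's note that overviews need the whole raster.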