Commit 71b0c23

Streaming TIFF write for dask inputs (#1084) (#1108)
* Add streaming TIFF write for dask inputs (#1084)

  to_geotiff() no longer calls .compute() on dask-backed DataArrays. Instead it writes one tile-row at a time: compute the row from the dask graph, compress each tile, write to disk, move on. A second pass patches the IFD offsets and byte counts. Peak memory is now roughly tile_height * width * bytes_per_sample instead of the full array. Works with tiled and stripped layouts and all compression codecs. COG still materializes, because overviews need the full array; for large-raster COGs, use VRT output (#1083).

* Add tests for streaming TIFF write (#1084)

  19 tests covering round-trip correctness (zstd, deflate, lzw, uncompressed, stripped, predictor, compression_level), geo metadata preservation (CRS, nodata, coordinates), edge cases (NaN handling, single chunk, uneven chunks, small raster, multiple dtypes), and COG fallback to the eager path.

* Update docstring and README for streaming dask write (#1084)

  Document the streaming behavior in to_geotiff's docstring and add a usage example to the README showing dask-to-single-TIFF output.

* Add user guide notebook for streaming GeoTIFF write (#1084)

  Demonstrates the three dask write modes: streaming to a single TIFF, streaming to VRT, and eager COG fallback. Includes a compression comparison and a summary table of when to use each mode.

* Address review: add multiband tests, BigTIFF comment, fail-fast URI check (#1084)

  - Add 3D band-last and band-first streaming write tests
  - Add forced bigtiff=True round-trip test
  - Add cloud URI rejection test
  - Note the uint32 offset limitation for BigTIFF files > 4 GB
  - Move fsspec URI check to top of write_streaming for fail-fast
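The "write one tile-row, then patch offsets in a second pass" pattern described above can be sketched with a toy example. This is NOT the library's actual writer and produces a flat binary file, not a TIFF: the function names, the 16-byte-per-row offset/count table, and zlib compression are illustrative stand-ins. A numpy slice plays the role of computing one dask tile-row, so only one row's worth of pixels is ever held at once.

```python
import struct
import zlib

import numpy as np


def stream_write(path, arr, row_height):
    """Two-pass streaming write: compress rows one at a time, patch table."""
    n_rows = (arr.shape[0] + row_height - 1) // row_height
    offsets, counts = [], []
    with open(path, 'wb') as f:
        # Pass 1: reserve a placeholder offset/byte-count table up front.
        table_pos = f.tell()
        f.write(b'\x00' * (n_rows * 16))
        for i in range(n_rows):
            # Stand-in for computing one tile-row from the dask graph.
            row = arr[i * row_height:(i + 1) * row_height]
            blob = zlib.compress(row.tobytes())
            offsets.append(f.tell())
            counts.append(len(blob))
            f.write(blob)
        # Pass 2: seek back and patch the real offsets and byte counts.
        f.seek(table_pos)
        for off, cnt in zip(offsets, counts):
            f.write(struct.pack('<QQ', off, cnt))


def stream_read(path, shape, dtype, row_height):
    """Read the table, then decompress each row back into place."""
    n_rows = (shape[0] + row_height - 1) // row_height
    out = np.empty(shape, dtype)
    with open(path, 'rb') as f:
        table = [struct.unpack('<QQ', f.read(16)) for _ in range(n_rows)]
        for i, (off, cnt) in enumerate(table):
            f.seek(off)
            row = np.frombuffer(zlib.decompress(f.read(cnt)), dtype)
            out[i * row_height:(i + 1) * row_height] = row.reshape(-1, shape[1])
    return out
```

The same trick works for uneven chunks: the last row is simply shorter, and the byte-count table records its true compressed size.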
1 parent ccfa6bd commit 71b0c23

File tree

5 files changed

+952
-16
lines changed


README.md

Lines changed: 1 addition & 0 deletions

@@ -169,6 +169,7 @@ open_geotiff('dem.tif', dtype='float32') # half memory
 open_geotiff('dem.tif', dtype='float32', chunks=512) # Dask + half memory
 to_geotiff(data, 'out.tif', compression_level=1) # fast scratch write
 to_geotiff(data, 'out.tif', compression_level=22) # max compression
+to_geotiff(dask_da, 'out.tif') # stream Dask to single TIFF
 to_geotiff(dask_da, 'mosaic.vrt') # stream Dask to VRT

 # Accessor methods
Lines changed: 257 additions & 0 deletions

@@ -0,0 +1,257 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Streaming GeoTIFF write from dask arrays\n",
    "\n",
    "When you call `to_geotiff()` on a dask-backed DataArray, the data is written one tile-row at a time. Only one tile-row lives in memory at once, so you can write rasters larger than RAM without switching to VRT output.\n",
    "\n",
    "This notebook shows the three write modes for dask data:\n",
    "1. **Streaming to a single TIFF** (automatic when the input is dask-backed)\n",
    "2. **Streaming to a VRT** (one file per chunk, stitched by an XML index)\n",
    "3. **Eager write** (materialise first, then write; needed for COG with overviews)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import tempfile\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "import dask.array as da\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from xrspatial.geotiff import open_geotiff, to_geotiff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build a dask-backed raster\n",
    "\n",
    "A 2000x2000 terrain surface chunked into 500x500 blocks. Four chunks along each axis, sixteen chunks total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(1084)\n",
    "H, W = 2000, 2000\n",
    "\n",
    "yy, xx = np.meshgrid(\n",
    "    np.linspace(0, 6 * np.pi, H),\n",
    "    np.linspace(0, 6 * np.pi, W),\n",
    "    indexing='ij',\n",
    ")\n",
    "terrain = (500 + 200 * np.sin(yy) * np.cos(xx * 0.7)\n",
    "           + 30 * rng.standard_normal((H, W))).astype(np.float32)\n",
    "\n",
    "y = np.linspace(45.0, 44.0, H)\n",
    "x = np.linspace(-122.0, -121.0, W)\n",
    "\n",
    "raster = xr.DataArray(\n",
    "    terrain, dims=['y', 'x'],\n",
    "    coords={'y': y, 'x': x},\n",
    "    attrs={'crs': 4326, 'nodata': -9999.0},\n",
    ")\n",
    "\n",
    "dask_raster = raster.chunk({'y': 500, 'x': 500})\n",
    "print(f'Shape: {dask_raster.shape}')\n",
    "print(f'Chunks: {dask_raster.chunks}')\n",
    "print(f'dtype: {dask_raster.dtype}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "raster.plot.imshow(ax=ax, cmap='terrain', add_colorbar=True)\n",
    "ax.set_title('Synthetic terrain (2000x2000)')\n",
    "ax.set_axis_off()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Streaming write to a single TIFF\n",
    "\n",
    "Pass the dask-backed DataArray to `to_geotiff()` the same way you would a numpy array. The streaming path kicks in automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmpdir = tempfile.mkdtemp(prefix='xrs_stream_nb_')\n",
    "\n",
    "tif_path = os.path.join(tmpdir, 'streamed.tif')\n",
    "to_geotiff(dask_raster, tif_path)\n",
    "\n",
    "print(f'File size: {os.path.getsize(tif_path):,} bytes')\n",
    "\n",
    "# Read back and verify\n",
    "loaded = open_geotiff(tif_path)\n",
    "print(f'Shape: {loaded.shape}')\n",
    "print(f'CRS: {loaded.attrs.get(\"crs\")}')\n",
    "print(f'Match: {np.allclose(loaded.values, raster.values)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's it. Same API, same output, but peak memory was roughly `tile_size * width * 4 bytes` instead of the full 2000x2000 array."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Streaming write to a VRT\n",
    "\n",
    "If you want one tile per dask chunk (useful when chunks are large or you plan to read subregions later), write to a `.vrt` path instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vrt_path = os.path.join(tmpdir, 'tiled.vrt')\n",
    "to_geotiff(dask_raster, vrt_path)\n",
    "\n",
    "tiles_dir = os.path.join(tmpdir, 'tiled_tiles')\n",
    "tile_files = sorted(os.listdir(tiles_dir))\n",
    "print(f'VRT file: {os.path.getsize(vrt_path):,} bytes')\n",
    "print(f'Tile count: {len(tile_files)}')\n",
    "print(f'Tiles: {tile_files}')\n",
    "\n",
    "mosaic = open_geotiff(vrt_path)\n",
    "print(f'\\nMosaic shape: {mosaic.shape}')\n",
    "print(f'Match: {np.allclose(mosaic.values, raster.values)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Four chunks along each axis produces 16 tile files, stitched by a lightweight XML index."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Compression and layout options\n",
    "\n",
    "All `to_geotiff` options work with the streaming path. Try different codecs and see the file size difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "codecs = ['none', 'deflate', 'zstd', 'lzw']\n",
    "sizes = {}\n",
    "\n",
    "for codec in codecs:\n",
    "    p = os.path.join(tmpdir, f'stream_{codec}.tif')\n",
    "    to_geotiff(dask_raster, p, compression=codec)\n",
    "    sizes[codec] = os.path.getsize(p)\n",
    "\n",
    "for codec, sz in sizes.items():\n",
    "    ratio = sz / sizes['none']\n",
    "    print(f'{codec:>8s}: {sz:>12,} bytes ({ratio:.2%} of uncompressed)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. When streaming doesn't apply\n",
    "\n",
    "COG output with `cog=True` needs overviews, which are built from the full array. In that case `to_geotiff` falls through to the eager path and calls `.compute()` as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cog_path = os.path.join(tmpdir, 'eager_cog.tif')\n",
    "to_geotiff(dask_raster, cog_path, cog=True)\n",
    "\n",
    "print(f'COG size: {os.path.getsize(cog_path):,} bytes')\n",
    "cog = open_geotiff(cog_path)\n",
    "print(f'Match: {np.allclose(cog.values, raster.values)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the full array doesn't fit in memory, use VRT output instead of COG."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "shutil.rmtree(tmpdir, ignore_errors=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Summary\n",
    "\n",
    "| Write mode | Path | Peak memory | When to use |\n",
    "|:-----------|:-----|:------------|:------------|\n",
    "| Streaming TIFF | `out.tif` | ~1 tile-row | Default for dask input |\n",
    "| Streaming VRT | `out.vrt` | ~1 chunk | Need per-chunk files |\n",
    "| Eager (COG) | `out.tif`, `cog=True` | Full array | Need overviews |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
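For reference, the "XML index" the notebook's VRT mode produces is a GDAL VRT file. A minimal sketch for one of the sixteen 500x500 tiles is below — the tile filename and subdirectory are illustrative stand-ins (the real names come from the writer), while the element and attribute names follow the actual GDAL VRT format. One `SimpleSource` per chunk maps a source tile into its position in the mosaic:

```xml
<VRTDataset rasterXSize="2000" rasterYSize="2000">
  <SRS>EPSG:4326</SRS>
  <GeoTransform>-122.0, 0.0005, 0.0, 45.0, 0.0, -0.0005</GeoTransform>
  <VRTRasterBand dataType="Float32" band="1">
    <NoDataValue>-9999</NoDataValue>
    <!-- hypothetical tile name; one SimpleSource per dask chunk -->
    <SimpleSource>
      <SourceFilename relativeToVRT="1">tiled_tiles/tile_0_0.tif</SourceFilename>
      <SourceBand>1</SourceBand>
      <SrcRect xOff="0" yOff="0" xSize="500" ySize="500"/>
      <DstRect xOff="0" yOff="0" xSize="500" ySize="500"/>
    </SimpleSource>
    <!-- ...fifteen more SimpleSource entries... -->
  </VRTRasterBand>
</VRTDataset>
```

Because each tile is a self-contained GeoTIFF, readers can open a subregion by touching only the tiles whose `DstRect` intersects it.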

xrspatial/geotiff/__init__.py

Lines changed: 64 additions & 16 deletions
@@ -387,6 +387,12 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
                gpu: bool | None = None) -> None:
     """Write data as a GeoTIFF or Cloud Optimized GeoTIFF.

+    Dask-backed DataArrays are written in streaming mode: one tile-row
+    at a time, without materialising the full array into RAM. Peak
+    memory is roughly ``tile_size * width * bytes_per_sample``. COG
+    output (``cog=True``) still materialises because overviews need the
+    full array.
+
     Automatically dispatches to GPU compression when:
     - ``gpu=True`` is passed, or
     - The input data is CuPy-backed (auto-detected)
@@ -483,25 +489,14 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
         wkt_fallback = crs

     if isinstance(data, xr.DataArray):
-        # Handle CuPy-backed DataArrays: convert to numpy for CPU write
         raw = data.data
-        if hasattr(raw, 'get'):
-            arr = raw.get()  # CuPy -> numpy
-        elif hasattr(raw, 'compute'):
-            arr = raw.compute()  # Dask -> numpy
-            if hasattr(arr, 'get'):
-                arr = arr.get()  # Dask+CuPy -> numpy
-        else:
-            arr = np.asarray(raw)
-        # Handle band-first dimension order (band, y, x) -> (y, x, band)
-        if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
-            arr = np.moveaxis(arr, 0, -1)
+
+        # Extract metadata from DataArray attrs (no materialisation needed)
         if geo_transform is None:
             geo_transform = _coords_to_transform(data)
         if epsg is None and crs is None:
             crs_attr = data.attrs.get('crs')
             if isinstance(crs_attr, str):
-                # WKT string from reproject() or other source
                 epsg = _wkt_to_epsg(crs_attr)
                 if epsg is None and wkt_fallback is None:
                     wkt_fallback = crs_attr
@@ -517,22 +512,75 @@
         nodata = data.attrs.get('nodata')
         if data.attrs.get('raster_type') == 'point':
             raster_type = RASTER_PIXEL_IS_POINT
-        # GDAL metadata from attrs (prefer raw XML, fall back to dict)
         gdal_meta_xml = data.attrs.get('gdal_metadata_xml')
         if gdal_meta_xml is None:
             gdal_meta_dict = data.attrs.get('gdal_metadata')
             if isinstance(gdal_meta_dict, dict):
                 from ._geotags import _build_gdal_metadata_xml
                 gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict)
-        # Extra tags for pass-through
         extra_tags_list = data.attrs.get('extra_tags')
-        # Resolution / DPI from attrs
         x_res = data.attrs.get('x_resolution')
         y_res = data.attrs.get('y_resolution')
         unit_str = data.attrs.get('resolution_unit')
         if unit_str is not None:
             _unit_ids = {'none': 1, 'inch': 2, 'centimeter': 3}
             res_unit = _unit_ids.get(str(unit_str), None)
+
+        # Dask-backed: stream tiles to avoid materialising the full array.
+        # COG requires overviews from the full array, so it falls through
+        # to the eager path.
+        if hasattr(raw, 'dask') and not cog:
+            dask_arr = raw
+            # Handle band-first dimension order (band, y, x) -> (y, x, band)
+            if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
+                import dask.array as da
+                dask_arr = da.moveaxis(raw, 0, -1)
+            if dask_arr.ndim not in (2, 3):
+                raise ValueError(
+                    f"Expected 2D or 3D array, got {dask_arr.ndim}D")
+            # Validate compression_level
+            if compression_level is not None:
+                level_range = _LEVEL_RANGES.get(compression.lower())
+                if level_range is not None:
+                    lo, hi = level_range
+                    if not (lo <= compression_level <= hi):
+                        raise ValueError(
+                            f"compression_level={compression_level} out of "
+                            f"range for {compression} (valid: {lo}-{hi})")
+            from ._writer import write_streaming
+            write_streaming(
+                dask_arr, path,
+                geo_transform=geo_transform,
+                crs_epsg=epsg,
+                crs_wkt=wkt_fallback if epsg is None else None,
+                nodata=nodata,
+                compression=compression,
+                compression_level=compression_level,
+                tiled=tiled,
+                tile_size=tile_size,
+                predictor=predictor,
+                raster_type=raster_type,
+                x_resolution=x_res,
+                y_resolution=y_res,
+                resolution_unit=res_unit,
+                gdal_metadata_xml=gdal_meta_xml,
+                extra_tags=extra_tags_list,
+                bigtiff=bigtiff,
+            )
+            return
+
+        # Eager compute (numpy, CuPy, or dask+COG)
+        if hasattr(raw, 'get'):
+            arr = raw.get()  # CuPy -> numpy
+        elif hasattr(raw, 'compute'):
+            arr = raw.compute()  # Dask -> numpy
+            if hasattr(arr, 'get'):
+                arr = arr.get()  # Dask+CuPy -> numpy
+        else:
+            arr = np.asarray(raw)
+        # Handle band-first dimension order (band, y, x) -> (y, x, band)
+        if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'):
+            arr = np.moveaxis(arr, 0, -1)
     else:
         if hasattr(data, 'get'):
             arr = data.get()  # CuPy -> numpy