Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 151 additions & 12 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,9 @@ def open_geotiff(source, *, dtype=None, window=None,
# Dask path (CPU)
if chunks is not None:
return read_geotiff_dask(source, dtype=dtype, chunks=chunks,
overview_level=overview_level, name=name)
overview_level=overview_level,
window=window, band=band,
max_pixels=max_pixels, name=name)

kwargs = {}
if max_pixels is not None:
Expand Down Expand Up @@ -964,15 +966,24 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *,
"max_z_error is not supported on the GPU writer "
"(nvCOMP has no LERC backend). Use the CPU path "
"(gpu=False) or omit max_z_error.")
# Strip output is not implemented on the GPU path; reject up
# front rather than silently producing a tiled file.
if not tiled:
raise ValueError(
"tiled=False is not supported on the GPU writer. "
"Pass gpu=False or omit tiled=False.")
try:
write_geotiff_gpu(data, path, crs=crs, nodata=nodata,
compression=compression,
compression_level=compression_level,
tiled=tiled,
tile_size=tile_size,
predictor=predictor,
cog=cog,
overview_levels=overview_levels,
overview_resampling=overview_resampling)
overview_resampling=overview_resampling,
bigtiff=bigtiff,
streaming_buffer_bytes=streaming_buffer_bytes)
return
except (ImportError, Exception):
pass # fall through to CPU path
Expand Down Expand Up @@ -1379,6 +1390,9 @@ def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None,

def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
overview_level: int | None = None,
window: tuple | None = None,
band: int | None = None,
max_pixels: int | None = None,
name: str | None = None) -> xr.DataArray:
"""Read a GeoTIFF as a dask-backed DataArray for out-of-core processing.

Expand All @@ -1395,6 +1409,21 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
Chunk size in pixels. Default 512.
overview_level : int or None
Overview level (0 = full resolution).
window : tuple or None
``(row_start, col_start, row_stop, col_stop)`` to restrict
chunking to a sub-region of the file. Chunks are laid out
relative to the window origin. None reads the full raster.
band : int or None
Zero-based band index. None returns all bands (3D for
multi-band files, 2D for single-band). Selecting a single band
produces a 2D DataArray.
max_pixels : int or None
Maximum allowed pixel count (width * height * samples) for the
windowed region. None uses the reader default (~1 billion).
The cap is checked once up-front against the lazy region; each
chunk task also re-checks against ``max_pixels`` so windowed
reads stay bounded even when ``read_to_array`` is invoked
directly.
name : str or None
Name for the DataArray.

Expand Down Expand Up @@ -1478,14 +1507,68 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
else:
target_dtype = effective_dtype

coords = _geo_to_coords(geo_info, full_h, full_w)
# Window clipping: restrict the lazy region to the requested
# sub-rectangle. ``read_to_array`` already accepts ``window=`` per
# chunk; we only need to remap the chunk grid so its origin moves to
# ``(win_r0, win_c0)`` and its extent shrinks to the window.
win_r0 = win_c0 = 0
if window is not None:
win_r0, win_c0, win_r1, win_c1 = window
if (win_r0 < 0 or win_c0 < 0
or win_r1 > full_h or win_c1 > full_w
or win_r0 >= win_r1 or win_c0 >= win_c1):
raise ValueError(
f"window={window} is outside the source extent "
f"({full_h}x{full_w}) or has non-positive size.")
# Mirror the eager-path windowed coord computation in open_geotiff.
t = geo_info.transform
if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
win_x = (np.arange(win_c0, win_c1, dtype=np.float64)
* t.pixel_width + t.origin_x)
win_y = (np.arange(win_r0, win_r1, dtype=np.float64)
* t.pixel_height + t.origin_y)
else:
win_x = (np.arange(win_c0, win_c1, dtype=np.float64)
* t.pixel_width + t.origin_x
+ t.pixel_width * 0.5)
win_y = (np.arange(win_r0, win_r1, dtype=np.float64)
* t.pixel_height + t.origin_y
+ t.pixel_height * 0.5)
coords = {'y': win_y, 'x': win_x}
full_h = win_r1 - win_r0
full_w = win_c1 - win_c0
else:
coords = _geo_to_coords(geo_info, full_h, full_w)

if band is not None:
if n_bands == 0:
if band != 0:
raise IndexError(
f"band={band} requested on a single-band file.")
elif not 0 <= band < n_bands:
raise IndexError(
f"band={band} out of range for {n_bands}-band file.")

# Up-front pixel-count guard against the windowed extent. Chunk
# tasks re-check via read_to_array's own ``max_pixels`` (which we
# forward through ``_delayed_read_window``), but catching an
# oversized request before any task is scheduled saves the caller
# from a misleading "tile size exceeds max_pixels" error in a
# chunk that happens to align with the file's tile grid.
if max_pixels is not None:
eff_bands = (1 if band is not None
else (n_bands if n_bands > 0 else 1))
if full_h * full_w * eff_bands > max_pixels:
raise ValueError(
f"Requested region {full_h}x{full_w}x{eff_bands} "
f"exceeds max_pixels={max_pixels:,}.")

if name is None:
import os
name = os.path.splitext(os.path.basename(source))[0]

attrs = {}
_populate_attrs_from_geo_info(attrs, geo_info)
_populate_attrs_from_geo_info(attrs, geo_info, window=window)
if nodata is not None:
attrs['nodata'] = nodata

Expand Down Expand Up @@ -1522,24 +1605,35 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,

# For multi-band, each window read returns (h, w, bands); for single-band (h, w)
# read_to_array with band=0 extracts a single band, band=None returns all
band_arg = None # return all bands (or 2D if single-band)
band_arg = band # None => all bands (or 2D for single-band file)

# When ``band`` is set, each chunk reads a 2D slice -- collapse the
# output dims so the returned DataArray is 2D regardless of file band
# count.
out_has_band_axis = band is None and n_bands > 0

dask_rows = []
for r0 in rows:
r1 = min(r0 + ch_h, full_h)
dask_cols = []
for c0 in cols:
c1 = min(c0 + ch_w, full_w)
if n_bands > 0:
if out_has_band_axis:
block_shape = (r1 - r0, c1 - c0, n_bands)
else:
block_shape = (r1 - r0, c1 - c0)
# Translate window-relative chunk coords back to file-relative
# coords for ``read_to_array``. ``win_r0`` / ``win_c0`` are 0
# when no window was requested.
block = da.from_delayed(
_delayed_read_window(source, r0, c0, r1, c1,
_delayed_read_window(source,
r0 + win_r0, c0 + win_c0,
r1 + win_r0, c1 + win_c0,
overview_level, nodata,
band_arg,
target_dtype=target_dtype if dtype is not None else None,
http_meta_key=http_meta_key),
http_meta_key=http_meta_key,
max_pixels=max_pixels),
shape=block_shape,
dtype=target_dtype,
)
Expand All @@ -1548,7 +1642,7 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,

dask_arr = da.concatenate(dask_rows, axis=0)

if n_bands > 0:
if out_has_band_axis:
dims = ['y', 'x', 'band']
coords['band'] = np.arange(n_bands)
else:
Expand All @@ -1560,7 +1654,8 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,


def _delayed_read_window(source, r0, c0, r1, c1, overview_level, nodata,
band, *, target_dtype=None, http_meta_key=None):
band, *, target_dtype=None, http_meta_key=None,
max_pixels=None):
"""Dask-delayed function to read a single window.

*http_meta_key* is an optional ``Delayed[(TIFFHeader, IFD)]`` parsed
Expand Down Expand Up @@ -1588,9 +1683,12 @@ def _read(http_meta):
and band is not None):
arr = arr[:, :, band]
else:
_r2a_kwargs = {}
if max_pixels is not None:
_r2a_kwargs['max_pixels'] = max_pixels
arr, _ = read_to_array(source, window=(r0, c0, r1, c1),
Comment on lines +1694 to 1697
overview_level=overview_level,
band=band)
band=band, **_r2a_kwargs)
if nodata is not None:
# ``arr`` was just decoded by ``_fetch_decode_cog_http_tiles``
# or ``read_to_array``; both return freshly-allocated buffers
Expand Down Expand Up @@ -2226,11 +2324,15 @@ def write_geotiff_gpu(data, path: str, *,
nodata=None,
compression: str = 'zstd',
compression_level: int | None = None,
tiled: bool = True,
tile_size: int = 256,
predictor: bool | int = False,
cog: bool = False,
overview_levels: list[int] | None = None,
overview_resampling: str = 'mean') -> None:
overview_resampling: str = 'mean',
bigtiff: bool | None = None,
max_z_error: float = 0.0,
streaming_buffer_bytes: int | None = None) -> None:
"""Write a CuPy-backed DataArray as a GeoTIFF with GPU compression.

Tiles are extracted and compressed on the GPU via nvCOMP, then
Expand Down Expand Up @@ -2260,6 +2362,12 @@ def write_geotiff_gpu(data, path: str, *,
compression_level : int or None
Compression effort level. Accepted for API compatibility but
currently ignored -- nvCOMP does not expose level control.
tiled : bool
Must be True (default). The GPU writer is tiled-only because
nvCOMP batch compression operates on per-tile streams; passing
``tiled=False`` raises ``ValueError`` rather than silently
producing a tiled file. Accepted for API parity with
``to_geotiff``.
tile_size : int
Tile size in pixels (default 256).
predictor : bool or int
Expand All @@ -2275,7 +2383,37 @@ def write_geotiff_gpu(data, path: str, *,
overview_resampling : str
Resampling method for overviews: 'mean' (default), 'nearest',
'min', 'max', 'median', or 'mode'.
bigtiff : bool or None
Force BigTIFF (64-bit offsets). None auto-promotes when the
estimated file size would exceed the classic-TIFF 4 GB limit.
max_z_error : float
Per-pixel error budget for LERC compression. The GPU writer
does not implement LERC (nvCOMP has no LERC backend), so any
non-zero value raises ``ValueError``. Accepted at the signature
level for API parity with ``to_geotiff``.
streaming_buffer_bytes : int or None
Accepted for API parity with ``to_geotiff``. The GPU writer
materialises the entire array on device and has no streaming
concept, so this kwarg is a no-op.
"""
if not tiled:
raise ValueError(
"write_geotiff_gpu requires tiled=True. nvCOMP batch "
"compression is tile-based; the strip layout is not "
"implemented on the GPU path. Use to_geotiff(..., gpu=False, "
"tiled=False) for strip output on CPU.")
if max_z_error < 0:
raise ValueError(
f"max_z_error must be >= 0, got {max_z_error}")
if max_z_error != 0:
raise ValueError(
"max_z_error is not supported on the GPU writer "
"(nvCOMP has no LERC backend). Use to_geotiff(..., gpu=False) "
"or omit max_z_error.")
# streaming_buffer_bytes is intentionally a no-op on the GPU path;
# the kwarg exists for API parity with to_geotiff so callers can pass
# the same kwargs to both entry points without filtering.
del streaming_buffer_bytes
try:
import cupy
except ImportError:
Expand Down Expand Up @@ -2444,6 +2582,7 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
x_resolution=x_res,
y_resolution=y_res,
resolution_unit=res_unit,
force_bigtiff=bigtiff,
)

_write_bytes(file_bytes, path)
Expand Down
Loading
Loading