Skip to content

Commit 98b0ae9

Browse files
committed
Add GPU COG overview support to write_geotiff_gpu (#1150)
write_geotiff_gpu now accepts cog, overview_levels, and overview_resampling. Overviews are block-reduced and compressed on GPU, then assembled into COG layout. to_geotiff() passes the new params through to the GPU path.
1 parent 46ee269 commit 98b0ae9

File tree

2 files changed

+135
-22
lines changed

2 files changed

+135
-22
lines changed

xrspatial/geotiff/__init__.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,10 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
465465
compression=compression,
466466
compression_level=compression_level,
467467
tile_size=tile_size,
468-
predictor=predictor)
468+
predictor=predictor,
469+
cog=cog,
470+
overview_levels=overview_levels,
471+
overview_resampling=overview_resampling)
469472
return
470473
except (ImportError, Exception):
471474
pass # fall through to CPU path
@@ -1154,14 +1157,21 @@ def write_geotiff_gpu(data, path: str, *,
11541157
compression: str = 'zstd',
11551158
compression_level: int | None = None,
11561159
tile_size: int = 256,
1157-
predictor: bool = False) -> None:
1160+
predictor: bool = False,
1161+
cog: bool = False,
1162+
overview_levels: list[int] | None = None,
1163+
overview_resampling: str = 'mean') -> None:
11581164
"""Write a CuPy-backed DataArray as a GeoTIFF with GPU compression.
11591165
11601166
Tiles are extracted and compressed on the GPU via nvCOMP, then
11611167
assembled into a TIFF file on CPU. The CuPy array stays on device
11621168
throughout compression -- only the compressed bytes transfer to CPU
11631169
for file writing.
11641170
1171+
When ``cog=True``, generates overview pyramids on GPU and writes a
1172+
Cloud Optimized GeoTIFF with all IFDs at the file start for
1173+
efficient range-request access.
1174+
11651175
Falls back to CPU compression if nvCOMP is not available.
11661176
11671177
Parameters
@@ -1184,13 +1194,22 @@ def write_geotiff_gpu(data, path: str, *,
11841194
Tile size in pixels (default 256).
11851195
predictor : bool
11861196
Apply horizontal differencing predictor.
1197+
cog : bool
1198+
Write as Cloud Optimized GeoTIFF with overviews.
1199+
overview_levels : list[int] or None
1200+
Overview decimation factors (e.g. [2, 4, 8]). Only used when
1201+
cog=True. If None and cog=True, auto-generates levels by
1202+
halving until the smallest overview fits in a single tile.
1203+
overview_resampling : str
1204+
Resampling method for overviews: 'mean' (default), 'nearest',
1205+
'min', 'max', 'median', or 'mode'.
11871206
"""
11881207
try:
11891208
import cupy
11901209
except ImportError:
11911210
raise ImportError("cupy is required for GPU writes")
11921211

1193-
from ._gpu_decode import gpu_compress_tiles
1212+
from ._gpu_decode import gpu_compress_tiles, make_overview_gpu
11941213
from ._writer import (
11951214
_compression_tag, _assemble_tiff, _write_bytes,
11961215
GeoTransform as _GT,
@@ -1245,28 +1264,45 @@ def write_geotiff_gpu(data, path: str, *,
12451264
comp_tag = _compression_tag(compression)
12461265
pred_val = 2 if predictor else 1
12471266

1248-
# GPU compress
1249-
compressed_tiles = gpu_compress_tiles(
1250-
arr, tile_size, tile_size, width, height,
1251-
comp_tag, pred_val, np_dtype, samples)
1252-
1253-
# Build offset/bytecount lists
1254-
rel_offsets = []
1255-
byte_counts = []
1256-
offset = 0
1257-
for tile in compressed_tiles:
1258-
rel_offsets.append(offset)
1259-
byte_counts.append(len(tile))
1260-
offset += len(tile)
1261-
1262-
# Assemble TIFF on CPU (only metadata + compressed bytes)
1263-
# _assemble_tiff needs an array in parts[0] to detect samples_per_pixel
1264-
shape_stub = np.empty((1, 1, samples) if samples > 1 else (1, 1), dtype=np_dtype)
1265-
parts = [(shape_stub, width, height, rel_offsets, byte_counts, compressed_tiles)]
1267+
def _gpu_compress_to_part(gpu_arr, w, h, spp):
1268+
"""Compress a GPU array into a (stub, w, h, offsets, counts, tiles) tuple."""
1269+
compressed = gpu_compress_tiles(
1270+
gpu_arr, tile_size, tile_size, w, h,
1271+
comp_tag, pred_val, np_dtype, spp)
1272+
rel_off = []
1273+
bc = []
1274+
off = 0
1275+
for tile in compressed:
1276+
rel_off.append(off)
1277+
bc.append(len(tile))
1278+
off += len(tile)
1279+
stub = np.empty((1, 1, spp) if spp > 1 else (1, 1), dtype=np_dtype)
1280+
return (stub, w, h, rel_off, bc, compressed)
1281+
1282+
# Full resolution
1283+
parts = [_gpu_compress_to_part(arr, width, height, samples)]
1284+
1285+
# Overview generation
1286+
if cog:
1287+
if overview_levels is None:
1288+
overview_levels = []
1289+
oh, ow = height, width
1290+
while oh > tile_size and ow > tile_size:
1291+
oh //= 2
1292+
ow //= 2
1293+
if oh > 0 and ow > 0:
1294+
overview_levels.append(len(overview_levels) + 1)
1295+
1296+
current = arr
1297+
for _ in overview_levels:
1298+
current = make_overview_gpu(current, method=overview_resampling)
1299+
oh, ow = current.shape[:2]
1300+
parts.append(_gpu_compress_to_part(current, ow, oh, samples))
12661301

12671302
file_bytes = _assemble_tiff(
12681303
width, height, np_dtype, comp_tag, predictor, True, tile_size,
1269-
parts, geo_transform, epsg, nodata, is_cog=False,
1304+
parts, geo_transform, epsg, nodata,
1305+
is_cog=(cog and len(parts) > 1),
12701306
raster_type=raster_type)
12711307

12721308
_write_bytes(file_bytes, path)

xrspatial/geotiff/_gpu_decode.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2317,3 +2317,80 @@ def gpu_compress_tiles(d_image, tile_width, tile_height,
23172317
result.append(cpu_compress(tile_data, compression))
23182318

23192319
return result
2320+
2321+
2322+
# ---------------------------------------------------------------------------
2323+
# GPU overview (pyramid) generation
2324+
# ---------------------------------------------------------------------------
2325+
2326+
GPU_OVERVIEW_METHODS = ('mean', 'nearest', 'min', 'max', 'median', 'mode')


def _block_reduce_2d_gpu(arr2d, method):
    """2x block-reduce a single 2D CuPy plane using *method*.

    A trailing odd row/column is cropped so 2x2 blocks tile the plane
    exactly. Float input keeps its dtype through the reduction; integer
    input for 'min'/'max' is reduced exactly in its native dtype, while
    other integer statistics go through float64 and are rounded back.

    Parameters
    ----------
    arr2d : cupy.ndarray
        2D array on GPU.
    method : str
        One of ``GPU_OVERVIEW_METHODS``.

    Returns
    -------
    cupy.ndarray
        Half-resolution plane with the input dtype.

    Raises
    ------
    ValueError
        If *method* is not a recognized resampling method.
    """
    import cupy

    # Validate up front so a bad method fails fast, before any
    # reshape/transfer work is done.
    if method not in GPU_OVERVIEW_METHODS:
        raise ValueError(
            f"Unknown GPU overview resampling method: {method!r}. "
            f"Use one of: {GPU_OVERVIEW_METHODS}")

    h, w = arr2d.shape
    h2 = (h // 2) * 2
    w2 = (w // 2) * 2
    cropped = arr2d[:h2, :w2]
    oh, ow = h2 // 2, w2 // 2

    if method == 'nearest':
        # Top-left pixel of each 2x2 block; copy to detach from parent.
        return cropped[::2, ::2].copy()

    if method == 'mode':
        # Mode is expensive on GPU; fall back to CPU
        cpu_arr = arr2d.get()
        from ._writer import _block_reduce_2d
        cpu_result = _block_reduce_2d(cpu_arr, 'mode')
        return cupy.asarray(cpu_result)

    is_float = arr2d.dtype.kind == 'f'

    # min/max on integer data must not round-trip through float64:
    # values above 2**53 are not exactly representable and would be
    # silently corrupted. Integers hold no NaN, so plain min/max are
    # equivalent to nanmin/nanmax here.
    if not is_float and method in ('min', 'max'):
        int_blocks = cropped.reshape(oh, 2, ow, 2)
        reducer = cupy.amin if method == 'min' else cupy.amax
        return reducer(int_blocks, axis=(1, 3))

    # Block reshape for mean/min/max/median. Integer input is promoted
    # to float64 so NaN-aware reductions work uniformly.
    if is_float:
        blocks = cropped.reshape(oh, 2, ow, 2)
    else:
        blocks = cropped.astype(cupy.float64).reshape(oh, 2, ow, 2)

    if method == 'mean':
        result = cupy.nanmean(blocks, axis=(1, 3))
    elif method == 'min':
        result = cupy.nanmin(blocks, axis=(1, 3))
    elif method == 'max':
        result = cupy.nanmax(blocks, axis=(1, 3))
    else:  # 'median' -- the only remaining validated method
        flat = blocks.transpose(0, 2, 1, 3).reshape(oh, ow, 4)
        result = cupy.nanmedian(flat, axis=2)

    if not is_float:
        # Round back to the integer input dtype.
        return cupy.around(result).astype(arr2d.dtype)
    return result.astype(arr2d.dtype)
2372+
2373+
2374+
def make_overview_gpu(arr, method='mean'):
    """Build a half-resolution (2x decimated) overview of *arr* on GPU.

    Parameters
    ----------
    arr : cupy.ndarray
        2D plane or 3D (height, width, bands) array resident on GPU.
    method : str
        Resampling method: 'mean', 'nearest', 'min', 'max', 'median',
        or 'mode'.

    Returns
    -------
    cupy.ndarray
        Half-resolution array on GPU.
    """
    import cupy

    # Single plane: reduce it directly.
    if arr.ndim != 3:
        return _block_reduce_2d_gpu(arr, method)

    # Multi-band: reduce each band plane independently, then restack
    # along the band axis.
    reduced_planes = []
    for band_idx in range(arr.shape[2]):
        reduced_planes.append(_block_reduce_2d_gpu(arr[:, :, band_idx], method))
    return cupy.stack(reduced_planes, axis=2)

0 commit comments

Comments
 (0)