Skip to content

Commit 6316ef4

Browse files
authored
Cap dask graph size in read_geotiff_dask and batch adler32 transfers (#1211)
read_geotiff_dask built one delayed task per chunk with no upper bound. For very large files at small chunk sizes the Python graph itself OOMs the driver before any pixel read runs (30TB at chunks=256 would produce ~125M chunks, ~500M tasks, ~500GB graph on the host). Cap total chunks at 1,000,000 and auto-scale the requested chunk size upward, emitting a UserWarning so callers know their request was adjusted. _nvcomp_batch_compress on the deflate path copied every uncompressed tile GPU->CPU one at a time with .get().tobytes() purely to compute the zlib adler32 trailer. Each per-tile .get() is a sync point on the default stream. Batch all tiles into a single contiguous device buffer, transfer once, then compute adler32 from a host memoryview slice per tile.
1 parent d05d9b7 commit 6316ef4

2 files changed

Lines changed: 35 additions & 5 deletions

File tree

xrspatial/geotiff/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,27 @@ def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512,
937937
else:
938938
ch_h, ch_w = chunks
939939

940+
# Graph-size guard. Each chunk becomes a delayed task whose Python graph
941+
# entry retains ~1KB. At very large chunk counts the graph itself OOMs
942+
# the driver before any read executes (30TB at chunks=256 => ~500M tasks
943+
# => ~500GB graph on host). Auto-scale chunks up to cap total task count.
944+
_MAX_DASK_CHUNKS = 1_000_000
945+
n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w)
946+
if n_chunks > _MAX_DASK_CHUNKS:
947+
import math
948+
scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS)
949+
new_ch_h = int(math.ceil(ch_h * scale))
950+
new_ch_w = int(math.ceil(ch_w * scale))
951+
import warnings
952+
warnings.warn(
953+
f"read_geotiff_dask: requested chunks=({ch_h}, {ch_w}) on a "
954+
f"{full_h}x{full_w} image would produce {n_chunks} dask tasks, "
955+
f"exceeding the {_MAX_DASK_CHUNKS}-task cap. Auto-scaling to "
956+
f"chunks=({new_ch_h}, {new_ch_w}).",
957+
stacklevel=2,
958+
)
959+
ch_h, ch_w = new_ch_h, new_ch_w
960+
940961
# Build dask array from delayed windowed reads
941962
rows = list(range(0, full_h, ch_h))
942963
cols = list(range(0, full_w, ch_w))

xrspatial/geotiff/_gpu_decode.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,15 +1813,24 @@ class _DeflateCompOpts(ctypes.Structure):
18131813
return None
18141814

18151815
# For deflate, compute adler32 checksums from uncompressed tiles
1816-
# before reading compressed data (need the originals)
1816+
# before reading compressed data (need the originals).
1817+
# Batch the GPU->CPU transfer so all tiles move in a single DMA
1818+
# instead of one .get() per tile (which serializes on the default
1819+
# stream and is the dominant cost on the deflate path).
18171820
adler_checksums = None
18181821
if compression in (8, 32946):
18191822
import zlib
18201823
import struct
1821-
adler_checksums = []
1822-
for i in range(n_tiles):
1823-
uncomp = d_tile_bufs[i].get().tobytes()
1824-
adler_checksums.append(zlib.adler32(uncomp))
1824+
adler_checksums = [None] * n_tiles
1825+
if n_tiles > 0:
1826+
d_contig = cupy.empty(n_tiles * tile_bytes, dtype=cupy.uint8)
1827+
for i in range(n_tiles):
1828+
d_contig[i * tile_bytes:(i + 1) * tile_bytes] = \
1829+
d_tile_bufs[i][:tile_bytes]
1830+
host_view = memoryview(d_contig.get())
1831+
for i in range(n_tiles):
1832+
adler_checksums[i] = zlib.adler32(
1833+
host_view[i * tile_bytes:(i + 1) * tile_bytes])
18251834

18261835
# Read compressed sizes and data back to CPU
18271836
comp_sizes = d_comp_sizes.get().astype(int)

0 commit comments

Comments (0)