Skip to content

Commit 443de96

Browse files
committed
Distributed streaming via dask.bag for multi-worker clusters (#1045)
When dask.distributed is active, the streaming path uses dask.bag to distribute tile batches across workers instead of processing everything in one process. Local mode (no cluster): a ThreadPoolExecutor within one process, bounded by max_memory. Distributed mode (dask cluster active): (1) partition the ~2M tiles into N batches, one per worker; (2) build dask.bag.from_sequence(batches, npartitions=N); (3) bag.map(process_batch), so each worker receives its own batch; (4) within each worker, a ThreadPoolExecutor provides intra-worker parallelism (Numba releases the GIL); (5) assemble the results. Graph-size comparison for 30 TB: old dask.array approach — 1,968,409 nodes (1.9 GB graph, OOM); new dask.bag approach — 4–64 nodes (one per worker). Each worker's memory remains bounded by the max_memory parameter. Distributed mode is auto-detected via get_client().
1 parent d835c16 commit 443de96

File tree

1 file changed

+100
-38
lines changed

1 file changed

+100
-38
lines changed

xrspatial/reproject/__init__.py

Lines changed: 100 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,51 @@ def _parse_max_memory(max_memory):
827827
return int(s)
828828

829829

def _process_tile_batch(batch, source_data, src_bounds, src_shape, y_desc,
                        src_wkt, tgt_wkt, resampling, nodata, precision,
                        max_memory_bytes, tile_mem):
    """Reproject one batch of output tiles inside a single worker.

    Intra-worker parallelism comes from a ThreadPoolExecutor — the
    Numba kernels release the GIL, so threads run concurrently.  The
    number of tiles resident at once is capped by
    ``max_memory_bytes // tile_mem``.

    Each element of ``batch`` is a job tuple
    ``(row_offset, col_offset, rchunk, cchunk, chunk_bounds)``.

    Returns a list of ``(row_offset, col_offset, tile_data)`` tuples.
    """
    # How many tiles fit in the memory budget (always at least one).
    budget = max(1, max_memory_bytes // max(tile_mem, 1))

    def _reproject_one(job):
        _, _, n_rows, n_cols, chunk_bounds = job
        return _reproject_chunk_numpy(
            source_data,
            src_bounds, src_shape, y_desc,
            src_wkt, tgt_wkt,
            chunk_bounds, (n_rows, n_cols),
            resampling, nodata, precision,
        )

    out = []
    if budget < 2 or len(batch) < 2:
        # Sequential path: either memory allows only one tile at a
        # time, or there is only one job to run.
        for job in batch:
            row_off, col_off = job[0], job[1]
            tile = _reproject_one(job)
            out.append((row_off, col_off, tile))
            del tile  # drop the local ref promptly
        return out

    import os
    from concurrent.futures import ThreadPoolExecutor

    n_threads = min(budget, len(batch), os.cpu_count() or 4)
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        # Walk the batch in fixed sub-batches of n_threads so no more
        # than n_threads tiles are held in memory at any moment.
        for start in range(0, len(batch), n_threads):
            sub = batch[start:start + n_threads]
            done = list(pool.map(_reproject_one, sub))
            for (row_off, col_off, _r, _c, _cb), tile in zip(sub, done):
                out.append((row_off, col_off, tile))
            del done  # release the sub-batch before the next one

    return out
830875
def _reproject_streaming(
831876
raster, src_bounds, src_shape, y_desc,
832877
src_wkt, tgt_wkt,
@@ -836,24 +881,22 @@ def _reproject_streaming(
836881
):
837882
"""Streaming reproject for datasets too large for dask's graph.
838883
839-
Uses a ThreadPoolExecutor with bounded concurrency based on
840-
max_memory. Numba kernels release the GIL, so threads give
841-
real parallelism. Each worker processes one output tile:
842-
compute coordinates, read source window, resample.
884+
Two modes:
885+
1. **Local** (no dask.distributed): ThreadPoolExecutor within one
886+
process, bounded by max_memory.
887+
2. **Distributed** (dask.distributed active): creates a dask.bag
888+
with one partition per worker, each partition processes its
889+
tile batch using threads. Graph size: O(n_workers), not
890+
O(n_tiles).
843891
844-
Memory usage: max_memory_bytes total across all concurrent tiles.
892+
Memory usage per worker: bounded by max_memory.
845893
"""
846894
if isinstance(tile_size, int):
847895
tile_size = (tile_size, tile_size)
848896

849897
row_chunks, col_chunks = _compute_chunk_layout(out_shape, tile_size)
850-
result = np.full(out_shape, nodata, dtype=np.float64)
851898

852-
# Compute how many tiles can run concurrently within memory budget.
853-
# Each tile needs: output (tile_size^2 * 8) + source window (~same)
854-
# + coordinates (tile_size^2 * 8 * 2)
855899
tile_mem = tile_size[0] * tile_size[1] * 8 * 4 # ~4 arrays per tile
856-
max_concurrent = max(1, max_memory_bytes // tile_mem)
857900

858901
# Build tile job list
859902
jobs = []
@@ -870,36 +913,55 @@ def _reproject_streaming(
870913
col_offset += cchunk
871914
row_offset += rchunk
872915

873-
def _process_tile(job):
874-
_, _, rchunk, cchunk, cb = job
875-
return _reproject_chunk_numpy(
876-
raster.data,
877-
src_bounds, src_shape, y_desc,
878-
src_wkt, tgt_wkt,
879-
cb, (rchunk, cchunk),
880-
resampling, nodata, precision,
916+
# Check if dask.distributed is active
917+
_use_distributed = False
918+
try:
919+
from dask.distributed import get_client
920+
client = get_client()
921+
n_distributed_workers = len(client.scheduler_info()['workers'])
922+
if n_distributed_workers > 0:
923+
_use_distributed = True
924+
except (ImportError, ValueError):
925+
pass
926+
927+
if _use_distributed and len(jobs) > n_distributed_workers:
928+
# Distributed: partition tiles across workers via dask.bag
929+
import dask.bag as db
930+
931+
# Split jobs into N partitions (one per worker)
932+
n_parts = min(n_distributed_workers, len(jobs))
933+
batch_size = math.ceil(len(jobs) / n_parts)
934+
batches = [jobs[i:i + batch_size] for i in range(0, len(jobs), batch_size)]
935+
936+
# Create bag and map the batch processor
937+
bag = db.from_sequence(batches, npartitions=len(batches))
938+
results_bag = bag.map(
939+
_process_tile_batch,
940+
source_data=raster.data,
941+
src_bounds=src_bounds, src_shape=src_shape, y_desc=y_desc,
942+
src_wkt=src_wkt, tgt_wkt=tgt_wkt,
943+
resampling=resampling, nodata=nodata, precision=precision,
944+
max_memory_bytes=max_memory_bytes, tile_mem=tile_mem,
881945
)
882946

883-
if max_concurrent >= 2 and len(jobs) > 1:
884-
import os
885-
from concurrent.futures import ThreadPoolExecutor
886-
n_workers = min(max_concurrent, len(jobs), os.cpu_count() or 4)
887-
with ThreadPoolExecutor(max_workers=n_workers) as pool:
888-
# Process in batches to bound memory
889-
for batch_start in range(0, len(jobs), n_workers):
890-
batch = jobs[batch_start:batch_start + n_workers]
891-
tiles = list(pool.map(_process_tile, batch))
892-
for job, tile in zip(batch, tiles):
893-
ro, co, rchunk, cchunk, _ = job
894-
result[ro:ro + rchunk, co:co + cchunk] = tile
895-
del tiles
896-
else:
897-
# Sequential fallback
898-
for job in jobs:
899-
ro, co, rchunk, cchunk, _ = job
900-
tile = _process_tile(job)
901-
result[ro:ro + rchunk, co:co + cchunk] = tile
902-
del tile
947+
# Compute all partitions and assemble result
948+
result = np.full(out_shape, nodata, dtype=np.float64)
949+
for batch_results in results_bag.compute():
950+
for ro, co, tile in batch_results:
951+
result[ro:ro + tile.shape[0], co:co + tile.shape[1]] = tile
952+
return result
953+
954+
# Local: ThreadPoolExecutor within one process
955+
result = np.full(out_shape, nodata, dtype=np.float64)
956+
batch_results = _process_tile_batch(
957+
jobs, raster.data,
958+
src_bounds, src_shape, y_desc,
959+
src_wkt, tgt_wkt,
960+
resampling, nodata, precision,
961+
max_memory_bytes, tile_mem,
962+
)
963+
for ro, co, tile in batch_results:
964+
result[ro:ro + tile.shape[0], co:co + tile.shape[1]] = tile
903965

904966
return result
905967

0 commit comments

Comments
 (0)