@@ -827,6 +827,51 @@ def _parse_max_memory(max_memory):
827827 return int (s )
828828
829829
def _process_tile_batch(batch, source_data, src_bounds, src_shape, y_desc,
                        src_wkt, tgt_wkt, resampling, nodata, precision,
                        max_memory_bytes, tile_mem):
    """Process a batch of output tiles within a single worker.

    Uses a ``ThreadPoolExecutor`` for intra-worker parallelism (the
    Numba kernels release the GIL, so threads give real parallelism).

    Parameters
    ----------
    batch : list of tuple
        Tile jobs ``(row_offset, col_offset, rchunk, cchunk,
        chunk_bounds)``.
    source_data : ndarray
        Full source raster array (read-only, shared by all tiles).
    src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, resampling, \
    nodata, precision
        Passed through unchanged to ``_reproject_chunk_numpy``.
    max_memory_bytes : int
        Memory budget used to bound the number of tiles resampled
        concurrently.
    tile_mem : int
        Estimated bytes needed per in-flight tile.

    Returns
    -------
    list of tuple
        ``(row_offset, col_offset, tile_data)`` in job order.

    Notes
    -----
    ``max_memory_bytes`` bounds only *concurrency*.  Every finished
    tile is retained in the returned list, so peak memory per worker
    is still proportional to ``len(batch)`` — the caller controls
    that through its batch size.
    """
    # Number of tiles that may be resampled at once within the budget;
    # guard against tile_mem == 0 and always allow at least one.
    max_concurrent = max(1, max_memory_bytes // max(tile_mem, 1))

    def _do_one(job):
        _, _, rchunk, cchunk, chunk_bounds = job
        return _reproject_chunk_numpy(
            source_data,
            src_bounds, src_shape, y_desc,
            src_wkt, tgt_wkt,
            chunk_bounds, (rchunk, cchunk),
            resampling, nodata, precision,
        )

    if max_concurrent >= 2 and len(batch) > 1:
        import os
        from concurrent.futures import ThreadPoolExecutor

        n_threads = min(max_concurrent, len(batch), os.cpu_count() or 4)
        # ``max_workers`` alone bounds how many tiles are resampled
        # concurrently, and ``pool.map`` yields results in submission
        # order, so no manual sub-batching is needed: a per-sub-batch
        # join would only stall idle threads behind stragglers without
        # reducing peak memory (every tile is kept in the result list
        # regardless).
        with ThreadPoolExecutor(max_workers=n_threads) as pool:
            tiles = pool.map(_do_one, batch)
            return [(job[0], job[1], tile)
                    for job, tile in zip(batch, tiles)]

    # Memory budget allows only one tile in flight (or there is only
    # one job): run sequentially.
    return [(job[0], job[1], _do_one(job)) for job in batch]
874+
830875def _reproject_streaming (
831876 raster , src_bounds , src_shape , y_desc ,
832877 src_wkt , tgt_wkt ,
@@ -836,24 +881,22 @@ def _reproject_streaming(
836881):
837882 """Streaming reproject for datasets too large for dask's graph.
838883
839- Uses a ThreadPoolExecutor with bounded concurrency based on
840- max_memory. Numba kernels release the GIL, so threads give
841- real parallelism. Each worker processes one output tile:
842- compute coordinates, read source window, resample.
884+ Two modes:
885+ 1. **Local** (no dask.distributed): ThreadPoolExecutor within one
886+ process, bounded by max_memory.
887+ 2. **Distributed** (dask.distributed active): creates a dask.bag
888+ with one partition per worker, each partition processes its
889+ tile batch using threads. Graph size: O(n_workers), not
890+ O(n_tiles).
843891
844- Memory usage: max_memory_bytes total across all concurrent tiles .
892+ Memory usage per worker: bounded by max_memory .
845893 """
846894 if isinstance (tile_size , int ):
847895 tile_size = (tile_size , tile_size )
848896
849897 row_chunks , col_chunks = _compute_chunk_layout (out_shape , tile_size )
850- result = np .full (out_shape , nodata , dtype = np .float64 )
851898
852- # Compute how many tiles can run concurrently within memory budget.
853- # Each tile needs: output (tile_size^2 * 8) + source window (~same)
854- # + coordinates (tile_size^2 * 8 * 2)
855899 tile_mem = tile_size [0 ] * tile_size [1 ] * 8 * 4 # ~4 arrays per tile
856- max_concurrent = max (1 , max_memory_bytes // tile_mem )
857900
858901 # Build tile job list
859902 jobs = []
@@ -870,36 +913,55 @@ def _reproject_streaming(
870913 col_offset += cchunk
871914 row_offset += rchunk
872915
873- def _process_tile (job ):
874- _ , _ , rchunk , cchunk , cb = job
875- return _reproject_chunk_numpy (
876- raster .data ,
877- src_bounds , src_shape , y_desc ,
878- src_wkt , tgt_wkt ,
879- cb , (rchunk , cchunk ),
880- resampling , nodata , precision ,
916+ # Check if dask.distributed is active
917+ _use_distributed = False
918+ try :
919+ from dask .distributed import get_client
920+ client = get_client ()
921+ n_distributed_workers = len (client .scheduler_info ()['workers' ])
922+ if n_distributed_workers > 0 :
923+ _use_distributed = True
924+ except (ImportError , ValueError ):
925+ pass
926+
927+ if _use_distributed and len (jobs ) > n_distributed_workers :
928+ # Distributed: partition tiles across workers via dask.bag
929+ import dask .bag as db
930+
931+ # Split jobs into N partitions (one per worker)
932+ n_parts = min (n_distributed_workers , len (jobs ))
933+ batch_size = math .ceil (len (jobs ) / n_parts )
934+ batches = [jobs [i :i + batch_size ] for i in range (0 , len (jobs ), batch_size )]
935+
936+ # Create bag and map the batch processor
937+ bag = db .from_sequence (batches , npartitions = len (batches ))
938+ results_bag = bag .map (
939+ _process_tile_batch ,
940+ source_data = raster .data ,
941+ src_bounds = src_bounds , src_shape = src_shape , y_desc = y_desc ,
942+ src_wkt = src_wkt , tgt_wkt = tgt_wkt ,
943+ resampling = resampling , nodata = nodata , precision = precision ,
944+ max_memory_bytes = max_memory_bytes , tile_mem = tile_mem ,
881945 )
882946
883- if max_concurrent >= 2 and len (jobs ) > 1 :
884- import os
885- from concurrent .futures import ThreadPoolExecutor
886- n_workers = min (max_concurrent , len (jobs ), os .cpu_count () or 4 )
887- with ThreadPoolExecutor (max_workers = n_workers ) as pool :
888- # Process in batches to bound memory
889- for batch_start in range (0 , len (jobs ), n_workers ):
890- batch = jobs [batch_start :batch_start + n_workers ]
891- tiles = list (pool .map (_process_tile , batch ))
892- for job , tile in zip (batch , tiles ):
893- ro , co , rchunk , cchunk , _ = job
894- result [ro :ro + rchunk , co :co + cchunk ] = tile
895- del tiles
896- else :
897- # Sequential fallback
898- for job in jobs :
899- ro , co , rchunk , cchunk , _ = job
900- tile = _process_tile (job )
901- result [ro :ro + rchunk , co :co + cchunk ] = tile
902- del tile
947+ # Compute all partitions and assemble result
948+ result = np .full (out_shape , nodata , dtype = np .float64 )
949+ for batch_results in results_bag .compute ():
950+ for ro , co , tile in batch_results :
951+ result [ro :ro + tile .shape [0 ], co :co + tile .shape [1 ]] = tile
952+ return result
953+
954+ # Local: ThreadPoolExecutor within one process
955+ result = np .full (out_shape , nodata , dtype = np .float64 )
956+ batch_results = _process_tile_batch (
957+ jobs , raster .data ,
958+ src_bounds , src_shape , y_desc ,
959+ src_wkt , tgt_wkt ,
960+ resampling , nodata , precision ,
961+ max_memory_bytes , tile_mem ,
962+ )
963+ for ro , co , tile in batch_results :
964+ result [ro :ro + tile .shape [0 ], co :co + tile .shape [1 ]] = tile
903965
904966 return result
905967
0 commit comments