@@ -453,6 +453,7 @@ def reproject(
453453 transform_precision = 16 ,
454454 chunk_size = None ,
455455 name = None ,
456+ max_memory = None ,
456457 src_vertical_crs = None ,
457458 tgt_vertical_crs = None ,
458459):
@@ -488,6 +489,12 @@ def reproject(
488489 Output chunk size for dask. Defaults to 512.
489490 name : str or None
490491 Name for the output DataArray.
492+ max_memory : int or str or None
493+ Maximum memory budget for the reprojection working set.
494+ Accepts bytes (int) or human-readable strings like ``'4GB'``,
495+ ``'512MB'``. Controls how many output tiles are processed
496+ in parallel for large-dataset streaming mode. Default None
497+ uses 1 GiB (1024**3 bytes). Has no effect for small datasets that fit in memory.
491498 src_vertical_crs : str or None
492499 Source vertical datum for height values. One of:
493500
@@ -610,6 +617,7 @@ def reproject(
610617 out_bounds , out_shape ,
611618 resampling , nd , transform_precision ,
612619 chunk_size or 2048 ,
620+ _parse_max_memory (max_memory ),
613621 )
614622 elif is_dask and is_cupy :
615623 result_data = _reproject_dask_cupy (
@@ -806,31 +814,49 @@ def _reproject_inmemory_cupy(
806814 )
807815
808816
817+ def _parse_max_memory (max_memory ):
818+ """Parse max_memory parameter to bytes. Accepts int, '4GB', '512MB'."""
819+ if max_memory is None :
820+ return 1024 * 1024 * 1024 # 1GB default
821+ if isinstance (max_memory , (int , float )):
822+ return int (max_memory )
823+ s = str (max_memory ).strip ().upper ()
824+ for suffix , factor in [('TB' , 1024 ** 4 ), ('GB' , 1024 ** 3 ), ('MB' , 1024 ** 2 ), ('KB' , 1024 )]:
825+ if s .endswith (suffix ):
826+ return int (float (s [:- len (suffix )]) * factor )
827+ return int (s )
828+
829+
809830def _reproject_streaming (
810831 raster , src_bounds , src_shape , y_desc ,
811832 src_wkt , tgt_wkt ,
812833 out_bounds , out_shape ,
813834 resampling , nodata , precision ,
814- tile_size ,
835+ tile_size , max_memory_bytes ,
815836):
816837 """Streaming reproject for datasets too large for dask's graph.
817838
818- Processes output tiles sequentially in a simple loop:
819- 1. For each output tile, compute source coordinates (Numba)
820- 2. Read only the needed source window from the (possibly mmap'd) source
821- 3. Resample and write the tile into the output array
822- 4. Free the tile before processing the next one
839+ Uses a ThreadPoolExecutor with bounded concurrency based on
840+ max_memory. Numba kernels release the GIL, so threads give
841+ real parallelism. Each worker processes one output tile:
842+ compute coordinates, read source window, resample.
823843
824- Memory usage is O(tile_size^2), not O(total_pixels). No dask graph
825- is created, so there's no graph-size overhead. The output is a numpy
826- array assembled tile by tile.
844+ Memory usage: max_memory_bytes total across all concurrent tiles.
827845 """
828846 if isinstance (tile_size , int ):
829847 tile_size = (tile_size , tile_size )
830848
831849 row_chunks , col_chunks = _compute_chunk_layout (out_shape , tile_size )
832850 result = np .full (out_shape , nodata , dtype = np .float64 )
833851
852+ # Compute how many tiles can run concurrently within memory budget.
853+ # Each tile needs: output (tile_size^2 * 8) + source window (~same)
854+ # + coordinates (tile_size^2 * 8 * 2)
855+ tile_mem = tile_size [0 ] * tile_size [1 ] * 8 * 4 # ~4 arrays per tile
856+ max_concurrent = max (1 , max_memory_bytes // tile_mem )
857+
858+ # Build tile job list
859+ jobs = []
834860 row_offset = 0
835861 for rchunk in row_chunks :
836862 col_offset = 0
@@ -840,19 +866,41 @@ def _reproject_streaming(
840866 row_offset , row_offset + rchunk ,
841867 col_offset , col_offset + cchunk ,
842868 )
843- tile = _reproject_chunk_numpy (
844- raster .data ,
845- src_bounds , src_shape , y_desc ,
846- src_wkt , tgt_wkt ,
847- cb , (rchunk , cchunk ),
848- resampling , nodata , precision ,
849- )
850- result [row_offset :row_offset + rchunk ,
851- col_offset :col_offset + cchunk ] = tile
852- del tile # free immediately
869+ jobs .append ((row_offset , col_offset , rchunk , cchunk , cb ))
853870 col_offset += cchunk
854871 row_offset += rchunk
855872
873+ def _process_tile (job ):
874+ _ , _ , rchunk , cchunk , cb = job
875+ return _reproject_chunk_numpy (
876+ raster .data ,
877+ src_bounds , src_shape , y_desc ,
878+ src_wkt , tgt_wkt ,
879+ cb , (rchunk , cchunk ),
880+ resampling , nodata , precision ,
881+ )
882+
883+ if max_concurrent >= 2 and len (jobs ) > 1 :
884+ import os
885+ from concurrent .futures import ThreadPoolExecutor
886+ n_workers = min (max_concurrent , len (jobs ), os .cpu_count () or 4 )
887+ with ThreadPoolExecutor (max_workers = n_workers ) as pool :
888+ # Process in batches to bound memory
889+ for batch_start in range (0 , len (jobs ), n_workers ):
890+ batch = jobs [batch_start :batch_start + n_workers ]
891+ tiles = list (pool .map (_process_tile , batch ))
892+ for job , tile in zip (batch , tiles ):
893+ ro , co , rchunk , cchunk , _ = job
894+ result [ro :ro + rchunk , co :co + cchunk ] = tile
895+ del tiles
896+ else :
897+ # Sequential fallback
898+ for job in jobs :
899+ ro , co , rchunk , cchunk , _ = job
900+ tile = _process_tile (job )
901+ result [ro :ro + rchunk , co :co + cchunk ] = tile
902+ del tile
903+
856904 return result
857905
858906
0 commit comments