@@ -574,37 +574,31 @@ def reproject(
574574 else :
575575 is_cupy = is_cupy_array (data )
576576
577- # For very large datasets, estimate whether a dask graph would fit
578- # in memory. Each dask task uses ~1KB of graph metadata. If the
579- # graph itself would exceed available memory, use a streaming
580- # approach instead of dask (process tiles sequentially, no graph) .
577+ # For large in-memory datasets, wrap in dask for chunked processing.
578+ # map_blocks generates an O(1) HighLevelGraph (single blockwise layer)
579+ # so graph metadata is no longer a concern -- the streaming fallback
580+ # is only needed when dask itself is unavailable .
581581 _use_streaming = False
582582 if not is_dask and not is_cupy :
583583 nbytes = src_shape [0 ] * src_shape [1 ] * data .dtype .itemsize
584584 if data .ndim == 3 :
585585 nbytes *= data .shape [2 ]
586586 _OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB
587587 if nbytes > _OOM_THRESHOLD :
588- # Estimate graph size for the output
589588 cs = chunk_size or 2048
590589 if isinstance (cs , int ):
591590 cs = (cs , cs )
592- n_out_chunks = (math .ceil (out_shape [0 ] / cs [0 ])
593- * math .ceil (out_shape [1 ] / cs [1 ]))
594- graph_bytes = n_out_chunks * 1024 # ~1KB per task
595-
596- if graph_bytes > 1024 * 1024 * 1024 : # > 1GB graph
597- # Graph too large for dask -- use streaming
598- _use_streaming = True
599- else :
600- # Graph fits -- use dask with large chunks
591+ try :
601592 import dask .array as _da
602593 data = _da .from_array (data , chunks = cs )
603594 raster = xr .DataArray (
604595 data , dims = raster .dims , coords = raster .coords ,
605596 name = raster .name , attrs = raster .attrs ,
606597 )
607598 is_dask = True
599+ except ImportError :
600+ # dask not available -- fall back to streaming
601+ _use_streaming = True
608602
609603 # Serialize CRS for pickle safety
610604 src_wkt = src_crs .to_wkt ()
@@ -1125,51 +1119,121 @@ def _reproject_dask_cupy(
11251119 return result
11261120
11271121
1122+ def _source_footprint_in_target (src_bounds , src_wkt , tgt_wkt ):
1123+ """Compute an approximate bounding box of the source raster in target CRS.
1124+
1125+ Transforms corners and edge midpoints (12 points) to handle non-linear
1126+ projections. Returns ``(left, bottom, right, top)`` in target CRS, or
1127+ *None* if the transform fails (e.g. out-of-domain).
1128+ """
1129+ try :
1130+ from ._crs_utils import _require_pyproj
1131+ pyproj = _require_pyproj ()
1132+ src_crs = pyproj .CRS (src_wkt )
1133+ tgt_crs = pyproj .CRS (tgt_wkt )
1134+ transformer = pyproj .Transformer .from_crs (
1135+ src_crs , tgt_crs , always_xy = True
1136+ )
1137+ except Exception :
1138+ return None
1139+
1140+ sl , sb , sr , st = src_bounds
1141+ mx = (sl + sr ) / 2
1142+ my = (sb + st ) / 2
1143+ xs = [sl , mx , sr , sl , mx , sr , sl , mx , sr , sl , sr , mx ]
1144+ ys = [sb , sb , sb , my , my , my , st , st , st , mx , mx , sb ]
1145+ try :
1146+ tx , ty = transformer .transform (xs , ys )
1147+ tx = [v for v in tx if np .isfinite (v )]
1148+ ty = [v for v in ty if np .isfinite (v )]
1149+ if not tx or not ty :
1150+ return None
1151+ return (min (tx ), min (ty ), max (tx ), max (ty ))
1152+ except Exception :
1153+ return None
1154+
1155+
1156+ def _bounds_overlap (a , b ):
1157+ """Return True if bounding boxes *a* and *b* overlap."""
1158+ return a [0 ] < b [2 ] and a [2 ] > b [0 ] and a [1 ] < b [3 ] and a [3 ] > b [1 ]
1159+
1160+
def _reproject_block_adapter(
    block, block_info, source_data,
    src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    is_cupy, src_footprint_tgt,
):
    """``map_blocks`` adapter for reprojection.

    Derives this chunk's output bounds from *block_info* and delegates
    to the per-chunk worker; *block* is only a template placeholder.
    """
    (r0, r1), (c0, c1) = block_info[0]['array-location']
    shape = (r1 - r0, c1 - c0)
    bounds = _chunk_bounds(out_bounds, out_shape, r0, r1, c0, c1)

    # Short-circuit: chunks entirely outside the source footprint are
    # pure nodata, so no reprojection work is needed.
    if src_footprint_tgt is not None:
        if not _bounds_overlap(bounds, src_footprint_tgt):
            return np.full(shape, nodata, dtype=np.float64)

    if is_cupy:
        worker = _reproject_chunk_cupy
    else:
        worker = _reproject_chunk_numpy
    return worker(
        source_data, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt,
        bounds, shape,
        resampling, nodata, precision,
    )
1191+
1192+
def _reproject_dask(
    raster, src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    chunk_size, is_cupy,
):
    """Dask+NumPy backend: ``map_blocks`` over a template array.

    The output is built from a single ``blockwise`` layer in the
    HighLevelGraph rather than O(N) ``dask.delayed`` nodes, so graph
    metadata stays O(1) regardless of output size.
    """
    import dask.array as da

    rows, cols = _compute_chunk_layout(out_shape, chunk_size)

    # Source footprint in target CRS enables empty-chunk skipping in the
    # adapter; None means "unknown -- never skip".
    footprint = _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt)

    canvas = da.empty(out_shape, dtype=np.float64, chunks=(rows, cols))

    return da.map_blocks(
        _reproject_block_adapter,
        canvas,
        source_data=raster.data,
        src_bounds=src_bounds,
        src_shape=src_shape,
        y_desc=y_desc,
        src_wkt=src_wkt,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        precision=precision,
        is_cupy=is_cupy,
        src_footprint_tgt=footprint,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )
11731237
11741238
11751239# ---------------------------------------------------------------------------
@@ -1434,37 +1498,49 @@ def _merge_inmemory(
14341498 return _merge_arrays_numpy (arrays , nodata , strategy )
14351499
14361500
def _merge_block_adapter(
    block, block_info,
    raster_data_list, src_bounds_list, src_shape_list, y_desc_list,
    src_wkt_list, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, strategy, precision,
    src_footprints_tgt,
):
    """``map_blocks`` adapter for merge.

    Reprojects every source raster that overlaps this chunk into the
    chunk's bounds and merges the results; returns an all-nodata chunk
    when no raster touches it.
    """
    (r0, r1), (c0, c1) = block_info[0]['array-location']
    shape = (r1 - r0, c1 - c0)
    bounds = _chunk_bounds(out_bounds, out_shape, r0, r1, c0, c1)

    pieces = []
    for data, s_bounds, s_shape, s_ydesc, s_wkt, footprint in zip(
        raster_data_list, src_bounds_list, src_shape_list,
        y_desc_list, src_wkt_list, src_footprints_tgt,
    ):
        # Skip rasters whose target-CRS footprint misses this chunk.
        if footprint is not None and not _bounds_overlap(bounds, footprint):
            continue
        pieces.append(_reproject_chunk_numpy(
            data, s_bounds, s_shape, s_ydesc,
            s_wkt, tgt_wkt,
            bounds, shape,
            resampling, nodata, precision,
        ))

    if not pieces:
        return np.full(shape, nodata, dtype=np.float64)
    return _merge_arrays_numpy(pieces, nodata, strategy)
14551534
14561535
def _merge_dask(
    raster_infos, tgt_wkt, out_bounds, out_shape,
    resampling, nodata, strategy, chunk_size,
):
    """Dask merge backend using ``map_blocks``.

    A single ``blockwise`` layer replaces per-chunk ``dask.delayed``
    nodes, keeping graph metadata O(1).
    """
    import dask.array as da

    rows, cols = _compute_chunk_layout(out_shape, chunk_size)

    # Flatten per-raster metadata into parallel lists for the adapter.
    datas = [info['raster'].data for info in raster_infos]
    all_bounds = [info['src_bounds'] for info in raster_infos]
    shapes = [info['src_shape'] for info in raster_infos]
    ydescs = [info['y_desc'] for info in raster_infos]
    wkts = [info['src_wkt'] for info in raster_infos]

    # Target-CRS footprints let the adapter skip non-overlapping rasters
    # per chunk (a None entry disables skipping for that raster).
    footprints = [
        _source_footprint_in_target(b, w, tgt_wkt)
        for b, w in zip(all_bounds, wkts)
    ]

    canvas = da.empty(out_shape, dtype=np.float64, chunks=(rows, cols))

    return da.map_blocks(
        _merge_block_adapter,
        canvas,
        raster_data_list=datas,
        src_bounds_list=all_bounds,
        src_shape_list=shapes,
        y_desc_list=ydescs,
        src_wkt_list=wkts,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        strategy=strategy,
        precision=16,  # fixed precision, matching the previous backend
        src_footprints_tgt=footprints,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )
0 commit comments