|
9 | 9 | """ |
10 | 10 | from __future__ import annotations |
11 | 11 |
|
| 12 | +import math |
| 13 | + |
12 | 14 | import numpy as np |
13 | 15 | import xarray as xr |
14 | 16 |
|
@@ -565,31 +567,51 @@ def reproject( |
565 | 567 | else: |
566 | 568 | is_cupy = is_cupy_array(data) |
567 | 569 |
|
568 | | - # Auto-chunk large non-dask arrays to prevent OOM. |
569 | | - # A 30TB float32 raster would instantly OOM if we called .values. |
570 | | - # Threshold: 512MB (configurable via chunk_size). |
| 570 | + # For very large datasets, estimate whether a dask graph would fit |
| 571 | + # in memory. Each dask task uses ~1KB of graph metadata. If the |
| 572 | + # graph itself would exceed available memory, use a streaming |
| 573 | + # approach instead of dask (process tiles sequentially, no graph). |
| 574 | + _use_streaming = False |
571 | 575 | if not is_dask and not is_cupy: |
572 | 576 | nbytes = src_shape[0] * src_shape[1] * data.dtype.itemsize |
573 | 577 | if data.ndim == 3: |
574 | 578 | nbytes *= data.shape[2] |
575 | 579 | _OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB |
576 | 580 | if nbytes > _OOM_THRESHOLD: |
577 | | - import dask.array as _da |
578 | | - cs = chunk_size or 512 |
| 581 | + # Estimate graph size for the output |
| 582 | + cs = chunk_size or 2048 |
579 | 583 | if isinstance(cs, int): |
580 | 584 | cs = (cs, cs) |
581 | | - data = _da.from_array(data, chunks=cs) |
582 | | - raster = xr.DataArray( |
583 | | - data, dims=raster.dims, coords=raster.coords, |
584 | | - name=raster.name, attrs=raster.attrs, |
585 | | - ) |
586 | | - is_dask = True |
| 585 | + n_out_chunks = (math.ceil(out_shape[0] / cs[0]) |
| 586 | + * math.ceil(out_shape[1] / cs[1])) |
| 587 | + graph_bytes = n_out_chunks * 1024 # ~1KB per task |
| 588 | + |
| 589 | + if graph_bytes > 1024 * 1024 * 1024: # > 1GB graph |
| 590 | + # Graph too large for dask -- use streaming |
| 591 | + _use_streaming = True |
| 592 | + else: |
| 593 | + # Graph fits -- use dask with large chunks |
| 594 | + import dask.array as _da |
| 595 | + data = _da.from_array(data, chunks=cs) |
| 596 | + raster = xr.DataArray( |
| 597 | + data, dims=raster.dims, coords=raster.coords, |
| 598 | + name=raster.name, attrs=raster.attrs, |
| 599 | + ) |
| 600 | + is_dask = True |
587 | 601 |
|
588 | 602 | # Serialize CRS for pickle safety |
589 | 603 | src_wkt = src_crs.to_wkt() |
590 | 604 | tgt_wkt = tgt_crs.to_wkt() |
591 | 605 |
|
592 | | - if is_dask and is_cupy: |
| 606 | + if _use_streaming: |
| 607 | + result_data = _reproject_streaming( |
| 608 | + raster, src_bounds, src_shape, y_desc, |
| 609 | + src_wkt, tgt_wkt, |
| 610 | + out_bounds, out_shape, |
| 611 | + resampling, nd, transform_precision, |
| 612 | + chunk_size or 2048, |
| 613 | + ) |
| 614 | + elif is_dask and is_cupy: |
593 | 615 | result_data = _reproject_dask_cupy( |
594 | 616 | raster, src_bounds, src_shape, y_desc, |
595 | 617 | src_wkt, tgt_wkt, |
@@ -784,6 +806,56 @@ def _reproject_inmemory_cupy( |
784 | 806 | ) |
785 | 807 |
|
786 | 808 |
|
def _reproject_streaming(
    raster, src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, precision,
    tile_size,
):
    """Streaming reproject for datasets too large for dask's graph.

    Processes output tiles sequentially in a simple loop:

    1. For each output tile, compute the tile's target-space bounds
    2. Reproject just that tile from the (possibly mmap'd) source via
       ``_reproject_chunk_numpy``
    3. Write the tile into the preallocated output array
    4. Drop the tile reference before processing the next one

    The full output array (``out_shape``, float64) is allocated up front,
    so resident memory is O(out_pixels) for the result itself; the point
    of streaming is that the *additional* working memory per step is only
    O(tile_size^2) and no dask task graph is built, avoiding graph-size
    overhead for inputs with huge pixel counts.

    Parameters
    ----------
    raster : xr.DataArray
        Source raster; only ``raster.data`` is handed to the per-tile
        reprojection kernel.
    src_bounds, src_shape, y_desc : source geometry descriptors,
        forwarded unchanged to ``_reproject_chunk_numpy``.
    src_wkt, tgt_wkt : str
        Source / target CRS serialized as WKT (pickle-safe).
    out_bounds, out_shape : target grid bounds and (rows, cols) shape.
    resampling, precision : resampling method and transform precision,
        forwarded unchanged.
    nodata : float or None
        Fill value for the output. ``None`` means "no nodata"; NaN is
        used as the initial fill in that case.
    tile_size : int or tuple[int, int]
        Output tile edge length(s); an int is used for both dimensions.

    Returns
    -------
    np.ndarray
        The reprojected output, assembled tile by tile.
    """
    if isinstance(tile_size, int):
        tile_size = (tile_size, tile_size)

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, tile_size)

    # np.full cannot take None as a fill value for a float dtype; fall
    # back to NaN, matching the conventional "no nodata" representation.
    fill_value = np.nan if nodata is None else nodata
    result = np.full(out_shape, fill_value, dtype=np.float64)

    row_offset = 0
    for rchunk in row_chunks:
        col_offset = 0
        for cchunk in col_chunks:
            # Bounds of this output tile in target CRS coordinates.
            cb = _chunk_bounds(
                out_bounds, out_shape,
                row_offset, row_offset + rchunk,
                col_offset, col_offset + cchunk,
            )
            tile = _reproject_chunk_numpy(
                raster.data,
                src_bounds, src_shape, y_desc,
                src_wkt, tgt_wkt,
                cb, (rchunk, cchunk),
                resampling, nodata, precision,
            )
            result[row_offset:row_offset + rchunk,
                   col_offset:col_offset + cchunk] = tile
            del tile  # free immediately to keep peak memory ~one tile
            col_offset += cchunk
        row_offset += rchunk

    return result
| 857 | + |
| 858 | + |
787 | 859 | def _reproject_dask_cupy( |
788 | 860 | raster, src_bounds, src_shape, y_desc, |
789 | 861 | src_wkt, tgt_wkt, |
|
0 commit comments