|
21 | 21 | from ._writer import write |
22 | 22 |
|
23 | 23 | __all__ = ['read_geotiff', 'write_geotiff', 'open_cog', 'read_geotiff_dask', |
24 | | - 'read_vrt', 'write_vrt', 'read_geotiff_gpu'] |
| 24 | + 'read_vrt', 'write_vrt', 'read_geotiff_gpu', 'write_geotiff_gpu'] |
25 | 25 |
|
26 | 26 |
|
27 | 27 | def _wkt_to_epsg(wkt_or_proj: str) -> int | None: |
@@ -661,6 +661,116 @@ def read_geotiff_gpu(source: str, *, |
661 | 661 | name=name, attrs=attrs) |
662 | 662 |
|
663 | 663 |
|
| 664 | +def write_geotiff_gpu(data, path: str, *, |
| 665 | + crs: int | str | None = None, |
| 666 | + nodata=None, |
| 667 | + compression: str = 'zstd', |
| 668 | + tile_size: int = 256, |
| 669 | + predictor: bool = False) -> None: |
| 670 | + """Write a CuPy-backed DataArray as a GeoTIFF with GPU compression. |
| 671 | +
|
| 672 | + Tiles are extracted and compressed on the GPU via nvCOMP, then |
| 673 | + assembled into a TIFF file on CPU. The CuPy array stays on device |
| 674 | + throughout compression -- only the compressed bytes transfer to CPU |
| 675 | + for file writing. |
| 676 | +
|
| 677 | + Falls back to CPU compression if nvCOMP is not available. |
| 678 | +
|
| 679 | + Parameters |
| 680 | + ---------- |
| 681 | + data : xr.DataArray (CuPy-backed) or cupy.ndarray |
| 682 | + 2D raster on GPU. |
| 683 | + path : str |
| 684 | + Output file path. |
| 685 | + crs : int, str, or None |
| 686 | + EPSG code or WKT string. |
| 687 | + nodata : float, int, or None |
| 688 | + NoData value. |
| 689 | + compression : str |
| 690 | + 'zstd' (default, fastest on GPU), 'deflate', or 'none'. |
| 691 | + tile_size : int |
| 692 | + Tile size in pixels (default 256). |
| 693 | + predictor : bool |
| 694 | + Apply horizontal differencing predictor. |
| 695 | + """ |
| 696 | + try: |
| 697 | + import cupy |
| 698 | + except ImportError: |
| 699 | + raise ImportError("cupy is required for GPU writes") |
| 700 | + |
| 701 | + from ._gpu_decode import gpu_compress_tiles |
| 702 | + from ._writer import ( |
| 703 | + _compression_tag, _assemble_tiff, _write_bytes, |
| 704 | + GeoTransform as _GT, |
| 705 | + ) |
| 706 | + from ._dtypes import numpy_to_tiff_dtype |
| 707 | + |
| 708 | + # Extract array and metadata |
| 709 | + geo_transform = None |
| 710 | + epsg = None |
| 711 | + raster_type = 1 |
| 712 | + |
| 713 | + if isinstance(crs, int): |
| 714 | + epsg = crs |
| 715 | + elif isinstance(crs, str): |
| 716 | + epsg = _wkt_to_epsg(crs) |
| 717 | + |
| 718 | + if isinstance(data, xr.DataArray): |
| 719 | + arr = data.data # keep as cupy |
| 720 | + if hasattr(arr, 'get'): |
| 721 | + # It's a CuPy array |
| 722 | + pass |
| 723 | + else: |
| 724 | + # Numpy DataArray -- send to GPU |
| 725 | + arr = cupy.asarray(data.values) |
| 726 | + |
| 727 | + geo_transform = _coords_to_transform(data) |
| 728 | + if epsg is None: |
| 729 | + epsg = data.attrs.get('crs') |
| 730 | + if nodata is None: |
| 731 | + nodata = data.attrs.get('nodata') |
| 732 | + if data.attrs.get('raster_type') == 'point': |
| 733 | + raster_type = RASTER_PIXEL_IS_POINT |
| 734 | + else: |
| 735 | + arr = cupy.asarray(data) if not hasattr(data, 'device') else data |
| 736 | + |
| 737 | + if arr.ndim not in (2, 3): |
| 738 | + raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D") |
| 739 | + |
| 740 | + height, width = arr.shape[:2] |
| 741 | + samples = arr.shape[2] if arr.ndim == 3 else 1 |
| 742 | + np_dtype = np.dtype(str(arr.dtype)) # cupy dtype -> numpy dtype |
| 743 | + |
| 744 | + comp_tag = _compression_tag(compression) |
| 745 | + pred_val = 2 if predictor else 1 |
| 746 | + |
| 747 | + # GPU compress |
| 748 | + compressed_tiles = gpu_compress_tiles( |
| 749 | + arr, tile_size, tile_size, width, height, |
| 750 | + comp_tag, pred_val, np_dtype, samples) |
| 751 | + |
| 752 | + # Build offset/bytecount lists |
| 753 | + rel_offsets = [] |
| 754 | + byte_counts = [] |
| 755 | + offset = 0 |
| 756 | + for tile in compressed_tiles: |
| 757 | + rel_offsets.append(offset) |
| 758 | + byte_counts.append(len(tile)) |
| 759 | + offset += len(tile) |
| 760 | + |
| 761 | + # Assemble TIFF on CPU (only metadata + compressed bytes) |
| 762 | + # _assemble_tiff needs an array in parts[0] to detect samples_per_pixel |
| 763 | + shape_stub = np.empty((1, 1, samples) if samples > 1 else (1, 1), dtype=np_dtype) |
| 764 | + parts = [(shape_stub, width, height, rel_offsets, byte_counts, compressed_tiles)] |
| 765 | + |
| 766 | + file_bytes = _assemble_tiff( |
| 767 | + width, height, np_dtype, comp_tag, predictor, True, tile_size, |
| 768 | + parts, geo_transform, epsg, nodata, is_cog=False, |
| 769 | + raster_type=raster_type) |
| 770 | + |
| 771 | + _write_bytes(file_bytes, path) |
| 772 | + |
| 773 | + |
664 | 774 | def read_vrt(source: str, *, window=None, |
665 | 775 | band: int | None = None, |
666 | 776 | name: str | None = None) -> xr.DataArray: |
|
0 commit comments