diff --git a/README.md b/README.md index 5ad23883..46dbcd14 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,12 @@ to_geotiff(data, 'out.tif', gpu=True) # force GPU compress to_geotiff(data, 'ortho.tif', compression='jpeg') # JPEG for orthophotos write_vrt('mosaic.vrt', ['tile1.tif', 'tile2.tif']) # generate VRT +open_geotiff('dem.tif', dtype='float32') # half memory +open_geotiff('dem.tif', dtype='float32', chunks=512) # Dask + half memory +to_geotiff(data, 'out.tif', compression_level=1) # fast scratch write +to_geotiff(data, 'out.tif', compression_level=22) # max compression +to_geotiff(dask_da, 'mosaic.vrt') # stream Dask to VRT + # Accessor methods da.xrs.to_geotiff('out.tif', compression='lzw') # write from DataArray ds.xrs.open_geotiff('large_dem.tif') # read windowed to Dataset extent diff --git a/docs/superpowers/plans/2026-03-30-geotiff-perf-controls.md b/docs/superpowers/plans/2026-03-30-geotiff-perf-controls.md new file mode 100644 index 00000000..e64d7666 --- /dev/null +++ b/docs/superpowers/plans/2026-03-30-geotiff-perf-controls.md @@ -0,0 +1,813 @@ +# GeoTIFF performance and memory controls implementation plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Add `dtype` to `open_geotiff`, `compression_level` to `to_geotiff`, and VRT tiled output when `to_geotiff` is given a `.vrt` path. Issue #1083. + +**Architecture:** Three independent features threaded into the existing geotiff module. `dtype` intercepts each read path after tile/strip decode. `compression_level` passes through `to_geotiff` → `write()` → `_write_tiled`/`_write_stripped` → `compress()`. VRT output adds a new code path in `to_geotiff` that slices the input into per-chunk GeoTIFFs and calls `write_vrt()`. + +**Tech Stack:** numpy, xarray, dask (optional), numba, cupy (optional). 
All existing dependencies. + +--- + +## File map + +| File | Role | Changes | +|------|------|---------| +| `xrspatial/geotiff/__init__.py` | Public API | Add `dtype` param to `open_geotiff`, `read_geotiff_dask`, `read_geotiff_gpu`, `_delayed_read_window`. Add `compression_level` param to `to_geotiff`, `write_geotiff_gpu`. Add VRT output path in `to_geotiff`. Add `_validate_dtype_cast()` helper. | +| `xrspatial/geotiff/_writer.py` | Tile/strip compression, file assembly | Thread `compression_level` through `write()`, `_write_tiled()`, `_write_stripped()`, `_prepare_tile()`. | +| `xrspatial/geotiff/_compression.py` | Codec dispatch | No changes needed -- `compress()` already accepts `level`. | +| `xrspatial/geotiff/tests/test_dtype_read.py` | New test file | Tests for `dtype` on eager, dask, validation. | +| `xrspatial/geotiff/tests/test_compression_level.py` | New test file | Tests for `compression_level` round-trips. | +| `xrspatial/geotiff/tests/test_vrt_write.py` | New test file | Tests for `.vrt` output path, dask streaming, numpy slicing, edge cases. | + +--- + +### Task 1: `compression_level` plumbing through the writer + +The simplest of the three features. Thread the level integer from the public API down to the `compress()` call. 
+ +**Files:** +- Modify: `xrspatial/geotiff/_writer.py:298-403` (`_write_stripped`, `_prepare_tile`, `_write_tiled`, `write`) +- Modify: `xrspatial/geotiff/__init__.py:342-519` (`to_geotiff`, `write_geotiff_gpu`) +- Test: `xrspatial/geotiff/tests/test_compression_level.py` (create) + +- [ ] **Step 1: Write the failing test** + +Create `xrspatial/geotiff/tests/test_compression_level.py`: + +```python +"""Tests for compression_level parameter on to_geotiff.""" +import numpy as np +import os +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +@pytest.fixture +def sample_float32(tmp_path): + """100x100 float32 raster with coords and CRS.""" + arr = np.random.default_rng(42).random((100, 100), dtype=np.float32) + y = np.linspace(40.0, 41.0, 100) + x = np.linspace(-105.0, -104.0, 100) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + return da + + +class TestCompressionLevel: + """Round-trip tests: write with level, read back, verify data matches.""" + + def test_zstd_level_1_round_trip(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_zstd_l1.tif') + to_geotiff(sample_float32, path, compression='zstd', + compression_level=1) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + sample_float32.values, decimal=6) + + def test_zstd_level_22_round_trip(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_zstd_l22.tif') + to_geotiff(sample_float32, path, compression='zstd', + compression_level=22) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + sample_float32.values, decimal=6) + + def test_deflate_level_1_round_trip(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_deflate_l1.tif') + to_geotiff(sample_float32, path, compression='deflate', + compression_level=1) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + 
sample_float32.values, decimal=6) + + def test_deflate_level_9_round_trip(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_deflate_l9.tif') + to_geotiff(sample_float32, path, compression='deflate', + compression_level=9) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + sample_float32.values, decimal=6) + + def test_higher_level_produces_smaller_file(self, sample_float32, tmp_path): + path_l1 = str(tmp_path / 'test_1083_small_l1.tif') + path_l22 = str(tmp_path / 'test_1083_small_l22.tif') + to_geotiff(sample_float32, path_l1, compression='zstd', + compression_level=1) + to_geotiff(sample_float32, path_l22, compression='zstd', + compression_level=22) + assert os.path.getsize(path_l22) <= os.path.getsize(path_l1) + + def test_level_none_uses_default(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_default.tif') + to_geotiff(sample_float32, path, compression='zstd', + compression_level=None) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + sample_float32.values, decimal=6) + + def test_level_ignored_for_lzw(self, sample_float32, tmp_path): + """LZW has no level support; setting one should not error.""" + path = str(tmp_path / 'test_1083_lzw_level.tif') + to_geotiff(sample_float32, path, compression='lzw', + compression_level=5) + result = open_geotiff(path) + np.testing.assert_array_almost_equal(result.values, + sample_float32.values, decimal=6) + + def test_invalid_level_raises(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_bad_level.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(sample_float32, path, compression='zstd', + compression_level=99) + + def test_invalid_deflate_level_raises(self, sample_float32, tmp_path): + path = str(tmp_path / 'test_1083_bad_deflate.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(sample_float32, path, compression='deflate', + 
compression_level=10) +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_compression_level.py -v --no-header -x 2>&1 | head -30` +Expected: FAIL -- `to_geotiff()` got an unexpected keyword argument `compression_level`. + +- [ ] **Step 3: Add `compression_level` validation to `to_geotiff`** + +In `xrspatial/geotiff/__init__.py`, change the `to_geotiff` signature and add validation before the write call. Add `compression_level: int | None = None` parameter after `compression`. Add this validation block before the `write()` call (before line 499): + +```python + # Validate compression_level + _LEVEL_RANGES = { + 'deflate': (1, 9), 'zstd': (1, 22), 'lz4': (0, 16), + } + if compression_level is not None: + level_range = _LEVEL_RANGES.get(compression) + if level_range is not None: + lo, hi = level_range + if not (lo <= compression_level <= hi): + raise ValueError( + f"compression_level={compression_level} out of range " + f"for {compression} (valid: {lo}-{hi})") +``` + +Pass `compression_level=compression_level` to the `write()` call at line 499. + +- [ ] **Step 4: Thread `compression_level` through `write()` → `_write_tiled` → `_prepare_tile`** + +In `xrspatial/geotiff/_writer.py`: + +1. Add `compression_level: int | None = None` parameter to `write()` (after `predictor`). +2. Pass `compression_level=compression_level` to `_write_tiled()` and `_write_stripped()` calls inside `write()`. +3. Add `compression_level: int | None = None` parameter to `_write_tiled()` and `_write_stripped()`. +4. Add `compression_level: int | None = None` parameter to `_prepare_tile()`. +5. In `_prepare_tile()`, change `return compress(tile_data, compression)` to `return compress(tile_data, compression, level=compression_level)` when `compression_level is not None`, else `return compress(tile_data, compression)`. 
Simplest: `return compress(tile_data, compression) if compression_level is None else compress(tile_data, compression, level=compression_level)`. +6. In `_write_stripped()`, do the same for the `compress(strip_data, compression)` call at the sequential path. +7. Pass `compression_level` through all `_prepare_tile` call sites in `_write_tiled`. + +The `compress()` function in `_compression.py` already accepts `level` as a keyword argument with default 6, so we just need to pass it when non-None. + +- [ ] **Step 5: Run tests to verify they pass** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_compression_level.py -v --no-header 2>&1 | tail -20` +Expected: All PASS. + +- [ ] **Step 6: Commit** + +```bash +cd .claude/worktrees/issue-1083 +git add xrspatial/geotiff/__init__.py xrspatial/geotiff/_writer.py xrspatial/geotiff/tests/test_compression_level.py +git commit -m "Add compression_level parameter to to_geotiff (#1083)" +``` + +--- + +### Task 2: `dtype` parameter on `open_geotiff` (eager and dask paths) + +**Files:** +- Modify: `xrspatial/geotiff/__init__.py:151-636` (`open_geotiff`, `read_geotiff_dask`, `_delayed_read_window`, `read_geotiff_gpu`) +- Test: `xrspatial/geotiff/tests/test_dtype_read.py` (create) + +- [ ] **Step 1: Write the failing test** + +Create `xrspatial/geotiff/tests/test_dtype_read.py`: + +```python +"""Tests for dtype parameter on open_geotiff.""" +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +@pytest.fixture +def float64_tif(tmp_path): + """Write a float64 GeoTIFF for dtype cast tests.""" + arr = np.random.default_rng(99).random((80, 80)).astype(np.float64) + y = np.linspace(40.0, 41.0, 80) + x = np.linspace(-105.0, -104.0, 80) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + path = str(tmp_path / 'test_1083_f64.tif') + to_geotiff(da, path, compression='none') + return path, arr + + 
+@pytest.fixture +def uint16_tif(tmp_path): + """Write a uint16 GeoTIFF for dtype cast tests.""" + arr = np.random.default_rng(77).integers(0, 10000, (60, 60), + dtype=np.uint16) + y = np.linspace(40.0, 41.0, 60) + x = np.linspace(-105.0, -104.0, 60) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + path = str(tmp_path / 'test_1083_u16.tif') + to_geotiff(da, path, compression='none') + return path, arr + + +class TestDtypeEager: + """dtype parameter on eager (numpy) reads.""" + + def test_float64_to_float32(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype='float32') + assert result.dtype == np.float32 + np.testing.assert_array_almost_equal( + result.values, orig.astype(np.float32), decimal=6) + + def test_float64_to_float16(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype=np.float16) + assert result.dtype == np.float16 + + def test_uint16_to_int32(self, uint16_tif): + path, orig = uint16_tif + result = open_geotiff(path, dtype='int32') + assert result.dtype == np.int32 + np.testing.assert_array_equal(result.values, orig.astype(np.int32)) + + def test_uint16_to_uint8(self, uint16_tif): + """Narrowing int cast is allowed (user asked for it).""" + path, _ = uint16_tif + result = open_geotiff(path, dtype='uint8') + assert result.dtype == np.uint8 + + def test_float_to_int_raises(self, float64_tif): + path, _ = float64_tif + with pytest.raises(ValueError, match='float.*int'): + open_geotiff(path, dtype='int32') + + def test_dtype_none_preserves_native(self, float64_tif): + path, _ = float64_tif + result = open_geotiff(path, dtype=None) + assert result.dtype == np.float64 + + +class TestDtypeDask: + """dtype parameter on dask reads.""" + + def test_float64_to_float32_dask(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype='float32', chunks=40) + assert result.dtype == np.float32 + computed = result.values + 
np.testing.assert_array_almost_equal( + computed, orig.astype(np.float32), decimal=6) + + def test_chunks_are_target_dtype(self, float64_tif): + path, _ = float64_tif + result = open_geotiff(path, dtype='float32', chunks=40) + # Each chunk should be float32, not float64 + assert result.data.dtype == np.float32 + + def test_float_to_int_raises_dask(self, float64_tif): + path, _ = float64_tif + with pytest.raises(ValueError, match='float.*int'): + open_geotiff(path, dtype='int32', chunks=40) +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_dtype_read.py -v --no-header -x 2>&1 | head -20` +Expected: FAIL -- `open_geotiff()` got an unexpected keyword argument `dtype`. + +- [ ] **Step 3: Add `_validate_dtype_cast` helper and `dtype` to `open_geotiff`** + +In `xrspatial/geotiff/__init__.py`, add a helper function after the `_geo_to_coords` function (around line 58): + +```python +def _validate_dtype_cast(source_dtype, target_dtype): + """Validate that casting source_dtype to target_dtype is allowed. + + Raises ValueError for float-to-int casts (lossy in a way users + often don't intend). All other casts are permitted -- the user + asked for them explicitly. + """ + src = np.dtype(source_dtype) + tgt = np.dtype(target_dtype) + if src.kind == 'f' and tgt.kind in ('u', 'i'): + raise ValueError( + f"Cannot cast float ({src}) to int ({tgt}). " + f"This loses fractional data and is usually unintentional. " + f"Cast explicitly after reading if you really want this.") +``` + +Then modify `open_geotiff` signature to add `dtype=None` after `source`. In the eager path (after `arr, geo_info = read_to_array(...)` at line 204), add: + +```python + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(arr.dtype, target) + arr = arr.astype(target) +``` + +Pass `dtype=dtype` through to `read_geotiff_dask()` and `read_geotiff_gpu()` calls. 
+ +- [ ] **Step 4: Add `dtype` to `read_geotiff_dask` and `_delayed_read_window`** + +In `read_geotiff_dask`: +1. Add `dtype` parameter to signature. +2. Before building dask blocks, validate: `if dtype is not None: target = np.dtype(dtype); _validate_dtype_cast(file_dtype, target)` where `file_dtype` is the dtype from the metadata read. +3. If dtype is set, use `target` instead of `dtype` (the file dtype) for `da.from_delayed(..., dtype=target)`. +4. Pass `dtype` to `_delayed_read_window`. + +In `_delayed_read_window`: +1. Add `target_dtype=None` parameter. +2. Inside the `_read()` closure, after the nodata masking, add: `if target_dtype is not None: arr = arr.astype(target_dtype)`. + +- [ ] **Step 5: Add `dtype` to `read_geotiff_gpu`** + +In `read_geotiff_gpu`: +1. Add `dtype` parameter to signature. +2. After the final `arr_gpu` is built (before building the DataArray), add: `if dtype is not None: target = np.dtype(dtype); _validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target); arr_gpu = arr_gpu.astype(target)`. + +- [ ] **Step 6: Run tests to verify they pass** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_dtype_read.py -v --no-header 2>&1 | tail -20` +Expected: All PASS. + +- [ ] **Step 7: Run existing tests to check for regressions** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/ -v --no-header -x -q 2>&1 | tail -20` +Expected: All PASS. 
+ +- [ ] **Step 8: Commit** + +```bash +cd .claude/worktrees/issue-1083 +git add xrspatial/geotiff/__init__.py xrspatial/geotiff/tests/test_dtype_read.py +git commit -m "Add dtype parameter to open_geotiff (#1083)" +``` + +--- + +### Task 3: VRT tiled output from `to_geotiff` + +**Files:** +- Modify: `xrspatial/geotiff/__init__.py:342-519` (`to_geotiff`) +- Test: `xrspatial/geotiff/tests/test_vrt_write.py` (create) + +- [ ] **Step 1: Write the failing tests** + +Create `xrspatial/geotiff/tests/test_vrt_write.py`: + +```python +"""Tests for VRT tiled output from to_geotiff.""" +import numpy as np +import os +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +@pytest.fixture +def sample_raster(): + """200x200 float32 raster with coords and CRS.""" + arr = np.random.default_rng(55).random((200, 200), dtype=np.float32) + y = np.linspace(41.0, 40.0, 200) # north-to-south + x = np.linspace(-106.0, -105.0, 200) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326, 'nodata': -9999.0}) + return da + + +class TestVrtOutputNumpy: + """VRT output from numpy-backed DataArrays.""" + + def test_creates_vrt_and_tiles_dir(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'out_1083.vrt') + to_geotiff(sample_raster, vrt_path) + assert os.path.exists(vrt_path) + tiles_dir = str(tmp_path / 'out_1083_tiles') + assert os.path.isdir(tiles_dir) + tile_files = os.listdir(tiles_dir) + assert len(tile_files) > 0 + assert all(f.endswith('.tif') for f in tile_files) + + def test_round_trip_numpy(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'rt_1083.vrt') + to_geotiff(sample_raster, vrt_path) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_tile_naming_convention(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'named_1083.vrt') + to_geotiff(sample_raster, vrt_path, tile_size=100) + 
tiles_dir = str(tmp_path / 'named_1083_tiles') + files = sorted(os.listdir(tiles_dir)) + # 200x200 with tile_size=100 -> 2x2 grid + assert files == [ + 'tile_00_00.tif', 'tile_00_01.tif', + 'tile_01_00.tif', 'tile_01_01.tif', + ] + + def test_relative_paths_in_vrt(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'rel_1083.vrt') + to_geotiff(sample_raster, vrt_path) + with open(vrt_path) as f: + content = f.read() + # Paths should be relative (no leading /) + assert 'rel_1083_tiles/' in content + assert str(tmp_path) not in content + + def test_compression_level_passed_to_tiles(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'cl_1083.vrt') + to_geotiff(sample_raster, vrt_path, compression='zstd', + compression_level=1) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + +class TestVrtOutputDask: + """VRT output from dask-backed DataArrays.""" + + def test_dask_round_trip(self, sample_raster, tmp_path): + dask_da = sample_raster.chunk({'y': 100, 'x': 100}) + vrt_path = str(tmp_path / 'dask_1083.vrt') + to_geotiff(dask_da, vrt_path) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_dask_one_tile_per_chunk(self, sample_raster, tmp_path): + dask_da = sample_raster.chunk({'y': 100, 'x': 100}) + vrt_path = str(tmp_path / 'chunks_1083.vrt') + to_geotiff(dask_da, vrt_path) + tiles_dir = str(tmp_path / 'chunks_1083_tiles') + # 200x200 chunked 100x100 -> 2x2 = 4 tiles + assert len(os.listdir(tiles_dir)) == 4 + + +class TestVrtEdgeCases: + """Edge cases and validation.""" + + def test_cog_with_vrt_raises(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'cog_1083.vrt') + with pytest.raises(ValueError, match='cog.*vrt|vrt.*cog'): + to_geotiff(sample_raster, vrt_path, cog=True) + + def test_overview_levels_with_vrt_raises(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 
'ovr_1083.vrt') + with pytest.raises(ValueError, match='overview.*vrt|vrt.*overview'): + to_geotiff(sample_raster, vrt_path, overview_levels=[2, 4]) + + def test_nonempty_tiles_dir_raises(self, sample_raster, tmp_path): + tiles_dir = tmp_path / 'exist_1083_tiles' + tiles_dir.mkdir() + (tiles_dir / 'dummy.tif').write_text('x') + vrt_path = str(tmp_path / 'exist_1083.vrt') + with pytest.raises(FileExistsError): + to_geotiff(sample_raster, vrt_path) + + def test_empty_tiles_dir_ok(self, sample_raster, tmp_path): + tiles_dir = tmp_path / 'empty_1083_tiles' + tiles_dir.mkdir() + vrt_path = str(tmp_path / 'empty_1083.vrt') + to_geotiff(sample_raster, vrt_path) + assert os.path.exists(vrt_path) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_vrt_write.py -v --no-header -x 2>&1 | head -20` +Expected: FAIL -- VRT path not yet handled. + +- [ ] **Step 3: Implement VRT output path in `to_geotiff`** + +In `xrspatial/geotiff/__init__.py`, add the VRT detection and dispatch at the top of `to_geotiff`, right after the docstring and before the GPU dispatch: + +```python + # VRT tiled output + if path.lower().endswith('.vrt'): + if cog: + raise ValueError( + "cog=True is not compatible with VRT output. " + "VRT writes tiled GeoTIFFs, not a single COG.") + if overview_levels is not None: + raise ValueError( + "overview_levels is not compatible with VRT output. 
" + "VRT tiles do not include overviews.") + _write_vrt_tiled(data, path, + crs=crs, nodata=nodata, + compression=compression, + compression_level=compression_level, + tile_size=tile_size, + predictor=predictor, + bigtiff=bigtiff, + gpu=gpu) + return +``` + +Then add the `_write_vrt_tiled` function (new function in `__init__.py`): + +```python +def _write_vrt_tiled(data, vrt_path: str, *, + crs=None, nodata=None, + compression='zstd', compression_level=None, + tile_size=256, predictor=False, + bigtiff=None, gpu=None): + """Write a DataArray as a directory of tiled GeoTIFFs with a VRT index. + + For dask inputs, each chunk is computed and written independently + so the full array never materialises in RAM. + """ + import os + import math + from ._vrt import write_vrt as _write_vrt_fn + + stem = os.path.splitext(os.path.basename(vrt_path))[0] + tiles_dir = os.path.join(os.path.dirname(vrt_path) or '.', f'{stem}_tiles') + + # Validate tiles directory + if os.path.isdir(tiles_dir) and os.listdir(tiles_dir): + raise FileExistsError( + f"Tiles directory already exists and is not empty: {tiles_dir}") + os.makedirs(tiles_dir, exist_ok=True) + + # Resolve metadata from the DataArray + epsg = None + wkt = None + nodata_val = nodata + geo_transform = None + + if isinstance(data, xr.DataArray): + geo_transform = _coords_to_transform(data) + if crs is None: + crs_attr = data.attrs.get('crs') + if isinstance(crs_attr, str): + epsg = _wkt_to_epsg(crs_attr) + if epsg is None: + wkt = crs_attr + elif crs_attr is not None: + epsg = int(crs_attr) + if epsg is None: + wkt_attr = data.attrs.get('crs_wkt') + if isinstance(wkt_attr, str): + epsg = _wkt_to_epsg(wkt_attr) + if epsg is None: + wkt = wkt_attr + elif isinstance(crs, int): + epsg = crs + elif isinstance(crs, str): + epsg = _wkt_to_epsg(crs) + if epsg is None: + wkt = crs + if nodata_val is None: + nodata_val = data.attrs.get('nodata') + + raw = data.data if isinstance(data, xr.DataArray) else data + is_dask = hasattr(raw, 
'dask')
+    is_cupy = hasattr(raw, 'device') or hasattr(raw, 'get')
+
+    if is_dask:
+        # Dask path: one tile per chunk
+        import dask
+        chunks_y = raw.chunks[0]
+        chunks_x = raw.chunks[1]
+        n_rows = len(chunks_y)
+        n_cols = len(chunks_x)
+    else:
+        # Numpy/CuPy path: slice by tile_size
+        if is_cupy:
+            arr = raw
+        else:
+            arr = np.asarray(raw)
+        h, w = arr.shape[:2]
+        n_rows = math.ceil(h / tile_size)
+        n_cols = math.ceil(w / tile_size)
+
+    # At least two digits so a 2x2 grid yields tile_00_00.tif, matching the
+    # naming convention asserted in test_tile_naming_convention and the
+    # zero-padded layout shown in the design spec.
+    pad_width = max(2, len(str(max(n_rows, n_cols) - 1)))
+    tile_paths = []
+
+    if is_dask:
+        delayed_writes = []
+        row_offset = 0
+        for ri, ch_y in enumerate(chunks_y):
+            col_offset = 0
+            for ci, ch_x in enumerate(chunks_x):
+                tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif'
+                tile_path = os.path.join(tiles_dir, tile_name)
+                tile_paths.append(tile_path)
+
+                # Extract the chunk as a dask array
+                chunk_slice = raw[
+                    row_offset:row_offset + ch_y,
+                    col_offset:col_offset + ch_x,
+                ]
+
+                # Build per-tile geo_transform
+                tile_gt = None
+                if geo_transform is not None:
+                    t = geo_transform
+                    tile_gt = GeoTransform(
+                        origin_x=t.origin_x + col_offset * t.pixel_width,
+                        origin_y=t.origin_y + row_offset * t.pixel_height,
+                        pixel_width=t.pixel_width,
+                        pixel_height=t.pixel_height,
+                    )
+
+                delayed_writes.append(
+                    dask.delayed(_write_single_tile)(
+                        chunk_slice, tile_path, tile_gt, epsg, wkt,
+                        nodata_val, compression, compression_level,
+                        tile_size, predictor, bigtiff))
+
+                col_offset += ch_x
+            row_offset += ch_y
+
+        dask.compute(*delayed_writes)
+
+    else:
+        # Numpy/CuPy: slice and write sequentially
+        h, w = arr.shape[:2]
+        for ri in range(n_rows):
+            for ci in range(n_cols):
+                r0 = ri * tile_size
+                c0 = ci * tile_size
+                r1 = min(r0 + tile_size, h)
+                c1 = min(c0 + tile_size, w)
+
+                tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif'
+                tile_path = os.path.join(tiles_dir, tile_name)
+                tile_paths.append(tile_path)
+
+                tile_data = arr[r0:r1, c0:c1]
+
+                tile_gt = None
+                if geo_transform is not None:
+                    t = geo_transform
+                    tile_gt = 
GeoTransform( + origin_x=t.origin_x + c0 * t.pixel_width, + origin_y=t.origin_y + r0 * t.pixel_height, + pixel_width=t.pixel_width, + pixel_height=t.pixel_height, + ) + + _write_single_tile( + tile_data, tile_path, tile_gt, epsg, wkt, + nodata_val, compression, compression_level, + tile_size, predictor, bigtiff) + + # Generate VRT index with relative paths + write_vrt(vrt_path, tile_paths, relative=True, + nodata=nodata_val) + + +def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt, + nodata, compression, compression_level, + tile_size, predictor, bigtiff): + """Write a single tile GeoTIFF. Used by _write_vrt_tiled.""" + if hasattr(chunk_data, 'compute'): + chunk_data = chunk_data.compute() + if hasattr(chunk_data, 'get'): + chunk_data = chunk_data.get() # CuPy -> numpy + + arr = np.asarray(chunk_data) + + # Auto-promote unsupported dtypes + if arr.dtype == np.float16: + arr = arr.astype(np.float32) + elif arr.dtype == np.bool_: + arr = arr.astype(np.uint8) + + # Restore NaN to nodata sentinel + if nodata is not None and arr.dtype.kind == 'f' and not np.isnan(nodata): + nan_mask = np.isnan(arr) + if nan_mask.any(): + arr = arr.copy() + arr[nan_mask] = arr.dtype.type(nodata) + + write(arr, path, + geo_transform=geo_transform, + crs_epsg=epsg, + crs_wkt=wkt if epsg is None else None, + nodata=nodata, + compression=compression, + tiled=True, + tile_size=tile_size, + predictor=predictor, + compression_level=compression_level, + bigtiff=bigtiff) +``` + +Note: The import of `GeoTransform` is already at the top of `__init__.py` (line 19). The import of `write_vrt` should come from `._vrt`. Adjust the import inside `_write_vrt_tiled` to: `from ._vrt import write_vrt as _write_vrt_fn` and call `_write_vrt_fn(...)`. + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/test_vrt_write.py -v --no-header 2>&1 | tail -30` +Expected: All PASS. 
+ +- [ ] **Step 5: Run full test suite to check for regressions** + +Run: `cd .claude/worktrees/issue-1083 && python -m pytest xrspatial/geotiff/tests/ -v --no-header -q 2>&1 | tail -20` +Expected: All PASS. + +- [ ] **Step 6: Commit** + +```bash +cd .claude/worktrees/issue-1083 +git add xrspatial/geotiff/__init__.py xrspatial/geotiff/tests/test_vrt_write.py +git commit -m "Add VRT tiled output from to_geotiff (#1083)" +``` + +--- + +### Task 4: Update documentation and README + +**Files:** +- Modify: `docs/source/reference/io.rst` (or equivalent -- check for existing geotiff docs) +- Modify: `README.md` + +- [ ] **Step 1: Update API docs** + +Check if `docs/source/reference/` has an entry for `open_geotiff`/`to_geotiff`. If so, no code change needed since the docstrings will auto-generate. If there's a manually maintained parameter list, add `dtype`, `compression_level`, and the `.vrt` extension behaviour. + +- [ ] **Step 2: Update README usage examples** + +In `README.md`, find the GeoTIFF I/O section (around line 140-201 based on the exploration). Add these examples to the existing list: + +```python +open_geotiff('dem.tif', dtype='float32') # half memory +open_geotiff('dem.tif', dtype='float32', chunks=512) # dask + half memory +to_geotiff(data, 'out.tif', compression_level=1) # fast scratch write +to_geotiff(data, 'out.tif', compression_level=22) # max compression +to_geotiff(dask_da, 'mosaic.vrt') # stream dask to VRT +``` + +- [ ] **Step 3: Commit** + +```bash +cd .claude/worktrees/issue-1083 +git add README.md docs/ +git commit -m "Update docs for dtype, compression_level, VRT output (#1083)" +``` + +--- + +### Task 5: User guide notebook + +**Files:** +- Create: `examples/user_guide/46_GeoTIFF_Performance.ipynb` + +- [ ] **Step 1: Create the notebook** + +Create `examples/user_guide/46_GeoTIFF_Performance.ipynb` with these cells: + +1. **Markdown: title** -- "GeoTIFF Performance Controls: dtype, compression_level, and VRT output" +2. 
**Code: imports** -- `import numpy as np, xarray as xr, os, tempfile` and `from xrspatial.geotiff import open_geotiff, to_geotiff` +3. **Markdown: dtype section** -- explain what `dtype` does and when to use it +4. **Code: create a float64 raster, write it, read back with dtype='float32'** -- show the memory savings (arr.nbytes before and after) +5. **Code: dask dtype** -- same with `chunks=256`, show `.dtype` on the result +6. **Markdown: compression_level section** -- explain the speed/size tradeoff +7. **Code: write same raster at level=1 and level=22** -- compare file sizes and write times with `%%time` +8. **Markdown: VRT output section** -- explain the streaming write and directory layout +9. **Code: create a larger raster, chunk it, write to .vrt** -- show the output directory listing +10. **Code: read the VRT back** -- round-trip verification +11. **Markdown: summary** -- one-paragraph recap + +- [ ] **Step 2: Run the notebook to verify it executes** + +Run: `cd .claude/worktrees/issue-1083 && jupyter nbconvert --to notebook --execute examples/user_guide/46_GeoTIFF_Performance.ipynb --output /dev/null 2>&1 | tail -5` +Expected: No errors. + +- [ ] **Step 3: Commit** + +```bash +cd .claude/worktrees/issue-1083 +git add examples/user_guide/46_GeoTIFF_Performance.ipynb +git commit -m "Add user guide notebook for geotiff performance controls (#1083)" +``` diff --git a/docs/superpowers/specs/2026-03-30-geotiff-perf-controls-design.md b/docs/superpowers/specs/2026-03-30-geotiff-perf-controls-design.md new file mode 100644 index 00000000..761d5f55 --- /dev/null +++ b/docs/superpowers/specs/2026-03-30-geotiff-perf-controls-design.md @@ -0,0 +1,147 @@ +# GeoTIFF Performance and Memory Controls + +Adds three parameters to `open_geotiff` and `to_geotiff` that let callers +control memory usage, compression speed, and large-raster write strategy. +All three are opt-in; default behaviour is unchanged. + +## 1. 
`dtype` parameter on `open_geotiff` + +### API + +```python +open_geotiff(source, *, dtype=None, ...) +``` + +`dtype` accepts any numpy dtype string or object (`np.float32`, `'float32'`, +etc.). `None` preserves the file's native dtype (current behaviour). + +### Read paths + +| Path | Behaviour | +|------|-----------| +| Eager (numpy) | Output array allocated at target dtype. Each decoded tile/strip cast before copy-in. Peak overhead: one tile at native dtype. | +| Dask | Each delayed chunk function casts after decode. Output chunks are target dtype. Same per-tile overhead. | +| GPU (CuPy) | Cast on device after decode. | +| Dask + CuPy | Combination of dask and GPU paths. | + +### Numba LZW fast path + +The LZW decoder is a numba JIT function that emits values one at a time into a +byte buffer. A variant will decode each value and cast inline to the target +dtype so the per-tile buffer is never allocated at native dtype. Other codecs +(deflate, zstd) return byte buffers from C libraries where per-value +interception isn't possible, so those fall back to the tile-level cast. + +### Validation + +- Narrowing float casts (float64 to float32): allowed. +- Narrowing int casts (int64 to int16): allowed (user asked for it explicitly). +- Widening casts (float32 to float64, uint8 to int32): allowed. +- Float to int: `ValueError` (lossy in a way users often don't intend). +- Unsupported casts (e.g. complex128 to uint8): `ValueError`. + +## 2. `compression_level` parameter on `to_geotiff` + +### API + +```python +to_geotiff(data, path, *, compression='zstd', compression_level=None, ...) +``` + +`compression_level` is `int | None`. `None` uses the codec's existing default. 
+ +### Ranges + +| Codec | Range | Default | Direction | +|-------|-------|---------|-----------| +| deflate | 1 -- 9 | 6 | 1 = fastest, 9 = smallest | +| zstd | 1 -- 22 | 3 | 1 = fastest, 22 = smallest | +| lz4 | 0 -- 16 | 0 | 0 = fastest | +| lzw | n/a | n/a | No level support; ignored silently | +| jpeg | n/a | n/a | Quality is a separate axis; ignored | +| packbits | n/a | n/a | Ignored | +| none | n/a | n/a | Ignored | + +### Plumbing + +`to_geotiff` passes `compression_level` to `write()`, which passes it to +`compress()`. The internal `compress()` already accepts a `level` argument; we +just thread it through the two intermediate call sites that currently hardcode +it. + +### Validation + +- Out-of-range level for a codec that supports levels: `ValueError`. +- Level set for a codec without level support: silently ignored. + +### GPU path + +`write_geotiff_gpu` accepts `compression_level` for API symmetry but currently +ignores it: nvCOMP batch compression does not expose level control. The +parameter is threaded through so GPU writes can gain level support without an +API change once nvCOMP allows it. + +## 3. VRT output from `to_geotiff` via `.vrt` extension + +### Trigger + +When `path` ends in `.vrt`, `to_geotiff` writes a tiled VRT instead of a +monolithic TIFF. No new parameter needed. + +### Output layout + +``` +output.vrt +output_tiles/ + tile_0000_0000.tif # row_col, zero-padded + tile_0000_0001.tif + ... +``` + +Directory name derived from the VRT stem (`foo.vrt` -> `foo_tiles/`). +Zero-padding width scales to the grid dimensions. + +### Behaviour per input type + +| Input | Tiling strategy | Memory profile | +|-------|----------------|----------------| +| Dask DataArray | One tile per dask chunk. Each task computes its chunk and writes one `.tif`. | One chunk in RAM at a time (scheduler controlled). | +| Dask + CuPy | Same, GPU compress per tile. | One chunk in GPU memory at a time. | +| Numpy / ndarray | Slice into `tile_size`-sized pieces, write each. | Source array already in RAM; tile slices are views (no duplication). | +| CuPy | Same as numpy but GPU compress.
| Source on GPU; tiles are views. | + +### Per-tile properties + +- Same `compression`, `compression_level`, `predictor`, `nodata`, `crs` as the + parent call. +- `tiled=True` with the caller's `tile_size` (internal TIFF tiling within each + chunk-file). +- GeoTransform adjusted to each tile's spatial position (row/col offset from + the full raster origin). +- No COG overviews on individual tiles. + +### VRT generation + +After all tiles are written, call `write_vrt()` with relative paths. The VRT +XML references each tile by its spatial extent and band mapping. + +### Edge cases and validation + +- `cog=True` with a `.vrt` path: `ValueError` (mutually exclusive). +- Tiles directory exists and is non-empty: `FileExistsError` to prevent silent + overwrites. +- Tiles directory doesn't exist: created automatically. +- `overview_levels` with `.vrt` path: `ValueError` (overviews don't apply). + +### Dask scheduling + +For dask inputs, all delayed tile-write tasks are submitted to a single +`dask.compute()` call with `scheduler='synchronous'`, so tiles are written +sequentially and at most one chunk is materialized at a time. Each task is: +compute chunk, compress, write tile file. No coordination between tasks. +NOTE(review): sequential execution bounds memory but gives up parallel writes; +revisit if throughput becomes a concern. + +## Out of scope + +- Streaming write of a monolithic `.tif` from dask input (tracked as a separate + issue). Users who need a single file from a large dask array can write to VRT + and convert externally, or ensure sufficient RAM. +- JPEG quality parameter (separate concern from compression level). +- Automatic chunk-size recommendation based on available memory.
diff --git a/examples/user_guide/46_GeoTIFF_Performance.ipynb b/examples/user_guide/46_GeoTIFF_Performance.ipynb new file mode 100644 index 00000000..bb8b06d1 --- /dev/null +++ b/examples/user_guide/46_GeoTIFF_Performance.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GeoTIFF Performance Controls\n", + "\n", + "Three opt-in controls for memory, compression speed, and large-raster writes:\n", + "\n", + "- **`dtype` on read** — cast to a narrower type at load time to reduce memory use\n", + "- **`compression_level` on write** — trade write speed for file size (or vice versa)\n", + "- **VRT tiled output** — stream a chunked dask array to a directory of tiles without loading it all at once" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import xarray as xr\n", + "import os\n", + "import tempfile\n", + "from xrspatial.geotiff import open_geotiff, to_geotiff" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a 1000x1000 float64 DEM-like raster (~8MB)\n", + "rng = np.random.default_rng(42)\n", + "elevation = rng.normal(loc=500, scale=100, size=(1000, 1000)).astype(np.float64)\n", + "y = np.linspace(40.0, 41.0, 1000)\n", + "x = np.linspace(-106.0, -105.0, 1000)\n", + "dem = xr.DataArray(elevation, dims=['y', 'x'],\n", + " coords={'y': y, 'x': x},\n", + " attrs={'crs': 4326})\n", + "print(f\"DEM shape: {dem.shape}, dtype: {dem.dtype}, size: {dem.nbytes / 1e6:.1f} MB\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
dtype on read\n", "\n", "Pass `dtype` to `open_geotiff` to cast the raster to a narrower type at load time.\n", "Reading a float64 file as float32 halves memory use; at most one decoded tile\n", "is ever held at the native dtype. The cast happens during tile decode, so it\n", "works on all read paths: eager, dask, and GPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with tempfile.TemporaryDirectory() as tmpdir:\n", " path = os.path.join(tmpdir, 'dem_f64.tif')\n", " to_geotiff(dem, path)\n", "\n", " # Read at native dtype (float64)\n", " native = open_geotiff(path)\n", " print(f\"Native: dtype={native.dtype}, size={native.nbytes / 1e6:.1f} MB\")\n", "\n", " # Read as float32 -- half the memory\n", " downcast = open_geotiff(path, dtype='float32')\n", " print(f\"Downcast: dtype={downcast.dtype}, size={downcast.nbytes / 1e6:.1f} MB\")\n", "\n", " # Works with dask too\n", " dask_f32 = open_geotiff(path, dtype='float32', chunks=256)\n", " print(f\"Dask chunks dtype: {dask_f32.dtype}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. compression_level\n", "\n", "`to_geotiff` accepts a `compression_level` argument alongside `compression`.\n", "For zstd the range is 1–22 (1 = fastest, 22 = smallest file).\n", "For deflate the range is 1–9.\n", "The default is the codec's own default when `compression_level` is omitted.\n", "\n", "Use a low level when write speed matters (streaming pipelines, scratch files).\n", "Use a high level for archival or network transfer where file size dominates."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "with tempfile.TemporaryDirectory() as tmpdir:\n", + " results = []\n", + " for level in [1, 3, 10, 22]:\n", + " path = os.path.join(tmpdir, f'dem_zstd_l{level}.tif')\n", + " t0 = time.perf_counter()\n", + " to_geotiff(dem, path, compression='zstd', compression_level=level)\n", + " elapsed = time.perf_counter() - t0\n", + " size_kb = os.path.getsize(path) / 1024\n", + " results.append((level, elapsed, size_kb))\n", + " print(f\" level={level:2d} time={elapsed:.3f}s size={size_kb:.0f} KB\")\n", + "\n", + " print(f\"\\nLevel 1 vs 22: {results[0][2]/results[-1][2]:.1f}x size difference\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. VRT tiled output\n", + "\n", + "Pass a `.vrt` path to `to_geotiff` and it writes a directory of GeoTIFF tiles\n", + "plus a VRT index file that GDAL treats as a single dataset.\n", + "\n", + "Each tile corresponds to one dask chunk and is written independently, so only\n", + "one chunk is in memory at a time. This makes it practical to write arrays that\n", + "are larger than RAM.\n", + "\n", + "The VRT uses relative paths, so the whole output directory is portable." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with tempfile.TemporaryDirectory() as tmpdir:\n", + " # Chunk the DEM for dask processing\n", + " dask_dem = dem.chunk({'y': 500, 'x': 500})\n", + "\n", + " vrt_path = os.path.join(tmpdir, 'tiled_dem.vrt')\n", + " to_geotiff(dask_dem, vrt_path, compression='zstd')\n", + "\n", + " # Show what was created\n", + " tiles_dir = os.path.join(tmpdir, 'tiled_dem_tiles')\n", + " print(\"Files created:\")\n", + " print(f\" {os.path.basename(vrt_path)}\")\n", + " for f in sorted(os.listdir(tiles_dir)):\n", + " size = os.path.getsize(os.path.join(tiles_dir, f)) / 1024\n", + " print(f\" tiled_dem_tiles/{f} ({size:.0f} KB)\")\n", + "\n", + " # Read it back via VRT\n", + " result = open_geotiff(vrt_path)\n", + " print(f\"\\nRound-trip: shape={result.shape}, dtype={result.dtype}\")\n", + " print(f\"Max difference: {float(np.abs(result.values - dem.values).max()):.2e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "| Feature | Parameter | Where to use |\n", + "|---|---|---|\n", + "| dtype cast | `open_geotiff(..., dtype='float32')` | Reduce read memory by half |\n", + "| compression level | `to_geotiff(..., compression_level=1)` | Fast scratch writes; set high for archival |\n", + "| VRT tiled output | `to_geotiff(..., 'out.vrt')` | Stream large dask arrays to disk without OOM |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index c5da8f8d..c0fe3044 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -57,6 +57,22 @@ def _geo_to_coords(geo_info, height: int, width: int) -> dict: return {'y': y, 'x': x} +def 
_validate_dtype_cast(source_dtype, target_dtype): + """Validate that casting source_dtype to target_dtype is allowed. + + Raises ValueError for float-to-int casts (lossy in a way users + often don't intend). All other casts are permitted -- the user + asked for them explicitly. + """ + src = np.dtype(source_dtype) + tgt = np.dtype(target_dtype) + if src.kind == 'f' and tgt.kind in ('u', 'i'): + raise ValueError( + f"Cannot cast float ({src}) to int ({tgt}). " + f"This loses fractional data and is usually unintentional. " + f"Cast explicitly after reading if you really want this.") + + def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: """Infer GeoTransform from DataArray coordinates. @@ -148,7 +164,7 @@ def _extent_to_window(transform, file_height, file_width, -def open_geotiff(source: str, *, window=None, +def open_geotiff(source: str, *, dtype=None, window=None, overview_level: int | None = None, band: int | None = None, name: str | None = None, @@ -168,6 +184,10 @@ def open_geotiff(source: str, *, window=None, ---------- source : str File path, HTTP URL, or cloud URI (s3://, gs://, az://). + dtype : str, numpy.dtype, or None + Cast the result to this dtype after reading. None keeps the + file's native dtype. Float-to-int casts raise ValueError to + prevent accidental data loss. window : tuple or None (row_start, col_start, row_stop, col_stop) for windowed reading. 
overview_level : int or None @@ -188,17 +208,18 @@ def open_geotiff(source: str, *, window=None, """ # VRT files if source.lower().endswith('.vrt'): - return read_vrt(source, window=window, band=band, name=name, - chunks=chunks, gpu=gpu) + return read_vrt(source, dtype=dtype, window=window, band=band, + name=name, chunks=chunks, gpu=gpu) # GPU path if gpu: - return read_geotiff_gpu(source, overview_level=overview_level, + return read_geotiff_gpu(source, dtype=dtype, + overview_level=overview_level, name=name, chunks=chunks) # Dask path (CPU) if chunks is not None: - return read_geotiff_dask(source, chunks=chunks, + return read_geotiff_dask(source, dtype=dtype, chunks=chunks, overview_level=overview_level, name=name) arr, geo_info = read_to_array( @@ -306,6 +327,11 @@ def open_geotiff(source: str, *, window=None, arr = arr.astype(np.float64) arr[mask] = np.nan + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(arr.dtype, target) + arr = arr.astype(target) + if arr.ndim == 3: dims = ['y', 'x', 'band'] coords['band'] = np.arange(arr.shape[2]) @@ -339,10 +365,18 @@ def _is_gpu_data(data) -> bool: return isinstance(data, _cupy_type) +_LEVEL_RANGES = { + 'deflate': (1, 9), + 'zstd': (1, 22), + 'lz4': (0, 16), +} + + def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, crs: int | str | None = None, nodata=None, compression: str = 'zstd', + compression_level: int | None = None, tiled: bool = True, tile_size: int = 256, predictor: bool = False, @@ -377,6 +411,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, JPEG is lossy and only supports uint8 data (1 or 3 bands). With ``gpu=True``, JPEG uses nvJPEG for GPU-accelerated encode/decode when available, falling back to Pillow on CPU. + compression_level : int or None + Compression effort level. None uses each codec's default (6 for + deflate/zstd). Valid ranges: deflate 1-9, zstd 1-22, lz4 0-16. 
+ Codecs without a level concept (lzw, packbits, jpeg) accept any + value and ignore it. tiled : bool Use tiled layout (default True). tile_size : int @@ -393,12 +432,33 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, gpu : bool or None Force GPU compression. None (default) auto-detects CuPy data. """ + # VRT tiled output + if path.lower().endswith('.vrt'): + if cog: + raise ValueError( + "cog=True is not compatible with VRT output. " + "VRT writes tiled GeoTIFFs, not a single COG.") + if overview_levels is not None: + raise ValueError( + "overview_levels is not compatible with VRT output. " + "VRT tiles do not include overviews.") + _write_vrt_tiled(data, path, + crs=crs, nodata=nodata, + compression=compression, + compression_level=compression_level, + tile_size=tile_size, + predictor=predictor, + bigtiff=bigtiff) + return + # Auto-detect GPU data and dispatch to write_geotiff_gpu use_gpu = gpu if gpu is not None else _is_gpu_data(data) if use_gpu: try: write_geotiff_gpu(data, path, crs=crs, nodata=nodata, - compression=compression, tile_size=tile_size, + compression=compression, + compression_level=compression_level, + tile_size=tile_size, predictor=predictor) return except (ImportError, Exception): @@ -496,6 +556,16 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, arr = arr.copy() arr[nan_mask] = arr.dtype.type(nodata) + # Validate compression_level against codec-specific range + if compression_level is not None: + level_range = _LEVEL_RANGES.get(compression.lower()) + if level_range is not None: + lo, hi = level_range + if not (lo <= compression_level <= hi): + raise ValueError( + f"compression_level={compression_level} out of range " + f"for {compression} (valid: {lo}-{hi})") + write( arr, path, geo_transform=geo_transform, @@ -503,6 +573,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, crs_wkt=wkt_fallback if epsg is None else None, nodata=nodata, compression=compression, + 
compression_level=compression_level, tiled=tiled, tile_size=tile_size, predictor=predictor, @@ -519,7 +590,207 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *, ) -def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, +def _write_single_tile(chunk_data, path, geo_transform, epsg, wkt, + nodata, compression, compression_level, + tile_size, predictor, bigtiff): + """Write a single tile GeoTIFF. Used by _write_vrt_tiled.""" + if hasattr(chunk_data, 'compute'): + chunk_data = chunk_data.compute() + if hasattr(chunk_data, 'get'): + chunk_data = chunk_data.get() # CuPy -> numpy + + arr = np.asarray(chunk_data) + + # Auto-promote unsupported dtypes + if arr.dtype == np.float16: + arr = arr.astype(np.float32) + elif arr.dtype == np.bool_: + arr = arr.astype(np.uint8) + + # Restore NaN to nodata sentinel + if nodata is not None and arr.dtype.kind == 'f' and not np.isnan(nodata): + nan_mask = np.isnan(arr) + if nan_mask.any(): + arr = arr.copy() + arr[nan_mask] = arr.dtype.type(nodata) + + write(arr, path, + geo_transform=geo_transform, + crs_epsg=epsg, + crs_wkt=wkt if epsg is None else None, + nodata=nodata, + compression=compression, + tiled=True, + tile_size=tile_size, + predictor=predictor, + compression_level=compression_level, + bigtiff=bigtiff) + + +def _write_vrt_tiled(data, vrt_path, *, crs=None, nodata=None, + compression='zstd', compression_level=None, + tile_size=256, predictor=False, bigtiff=None): + """Write a DataArray as a directory of tiled GeoTIFFs with a VRT index. + + This enables streaming dask arrays to disk without materializing the + full array in RAM. 
+ """ + import os + + # Validate compression_level against codec-specific range + if compression_level is not None: + level_range = _LEVEL_RANGES.get(compression.lower()) + if level_range is not None: + lo, hi = level_range + if not (lo <= compression_level <= hi): + raise ValueError( + f"compression_level={compression_level} out of range " + f"for {compression} (valid: {lo}-{hi})") + + # Derive tiles directory from VRT path stem + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + stem = os.path.splitext(os.path.basename(vrt_path))[0] + tiles_dir_name = stem + '_tiles' + tiles_dir = os.path.join(vrt_dir, tiles_dir_name) + + # Validate tiles directory + if os.path.isdir(tiles_dir) and os.listdir(tiles_dir): + raise FileExistsError( + f"Tiles directory already contains files: {tiles_dir}") + os.makedirs(tiles_dir, exist_ok=True) + + # Resolve CRS + epsg = None + wkt_fallback = None + if isinstance(crs, int): + epsg = crs + elif isinstance(crs, str): + epsg = _wkt_to_epsg(crs) + if epsg is None: + wkt_fallback = crs + + geo_transform = None + + if isinstance(data, xr.DataArray): + raw = data.data + if epsg is None and crs is None: + crs_attr = data.attrs.get('crs') + if isinstance(crs_attr, str): + epsg = _wkt_to_epsg(crs_attr) + if epsg is None and wkt_fallback is None: + wkt_fallback = crs_attr + elif crs_attr is not None: + epsg = int(crs_attr) + if epsg is None: + wkt = data.attrs.get('crs_wkt') + if isinstance(wkt, str): + epsg = _wkt_to_epsg(wkt) + if epsg is None and wkt_fallback is None: + wkt_fallback = wkt + if nodata is None: + nodata = data.attrs.get('nodata') + geo_transform = _coords_to_transform(data) + else: + raw = data + + # Check for dask backing + is_dask = hasattr(raw, 'dask') + + if is_dask: + if raw.ndim != 2: + raise ValueError( + "VRT tiled output currently supports 2D arrays only, " + f"got {raw.ndim}D. 
Squeeze or select a band first.") + # Use dask chunk grid + import dask + row_chunks = raw.chunks[0] # tuple of chunk sizes along y + col_chunks = raw.chunks[1] # tuple of chunk sizes along x + n_row_tiles = len(row_chunks) + n_col_tiles = len(col_chunks) + else: + # Numpy: tile using tile_size + if hasattr(raw, 'get'): + np_arr = raw.get() # CuPy + elif hasattr(raw, 'compute'): + np_arr = raw.compute() + else: + np_arr = np.asarray(raw) + if np_arr.ndim != 2: + raise ValueError( + "VRT tiled output currently supports 2D arrays only, " + f"got {np_arr.ndim}D. Squeeze or select a band first.") + height, width = np_arr.shape[:2] + n_row_tiles = (height + tile_size - 1) // tile_size + n_col_tiles = (width + tile_size - 1) // tile_size + + # Zero-padding width for tile names + pad_width = max(2, len(str(max(n_row_tiles, n_col_tiles) - 1))) + + tile_paths = [] + delayed_tasks = [] + + row_offset = 0 + for ri in range(n_row_tiles): + if is_dask: + chunk_h = row_chunks[ri] + else: + chunk_h = min(tile_size, height - row_offset) + + col_offset = 0 + for ci in range(n_col_tiles): + if is_dask: + chunk_w = col_chunks[ci] + else: + chunk_w = min(tile_size, width - col_offset) + + tile_name = f'tile_{ri:0{pad_width}d}_{ci:0{pad_width}d}.tif' + tile_path = os.path.join(tiles_dir, tile_name) + tile_paths.append(tile_path) + + # Compute per-tile geo_transform + tile_gt = None + if geo_transform is not None: + tile_gt = GeoTransform( + origin_x=geo_transform.origin_x + col_offset * geo_transform.pixel_width, + origin_y=geo_transform.origin_y + row_offset * geo_transform.pixel_height, + pixel_width=geo_transform.pixel_width, + pixel_height=geo_transform.pixel_height, + ) + + if is_dask: + # Slice the dask array for this chunk + r_end = row_offset + chunk_h + c_end = col_offset + chunk_w + chunk_data = raw[row_offset:r_end, col_offset:c_end] + + task = dask.delayed(_write_single_tile)( + chunk_data, tile_path, tile_gt, epsg, wkt_fallback, + nodata, compression, compression_level, + 
tile_size, predictor, bigtiff) + delayed_tasks.append(task) + else: + # Numpy: slice and write directly + chunk_data = np_arr[row_offset:row_offset + chunk_h, + col_offset:col_offset + chunk_w] + _write_single_tile( + chunk_data, tile_path, tile_gt, epsg, wkt_fallback, + nodata, compression, compression_level, + tile_size, predictor, bigtiff) + + col_offset += chunk_w + row_offset += chunk_h + + # Execute all dask tasks + if delayed_tasks: + import dask + dask.compute(*delayed_tasks, scheduler='synchronous') + + # Write VRT index with relative paths + from ._vrt import write_vrt as _write_vrt_fn + _write_vrt_fn(vrt_path, tile_paths, relative=True, nodata=nodata) + + +def read_geotiff_dask(source: str, *, dtype=None, chunks: int | tuple = 512, overview_level: int | None = None, name: str | None = None) -> xr.DataArray: """Read a GeoTIFF as a dask-backed DataArray for out-of-core processing. @@ -530,6 +801,9 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, ---------- source : str File path. + dtype : str, numpy.dtype, or None + Cast each chunk to this dtype after reading. None keeps the + file's native dtype. Float-to-int casts raise ValueError. chunks : int or (row_chunk, col_chunk) tuple Chunk size in pixels. Default 512. overview_level : int or None @@ -546,13 +820,27 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, # VRT files: delegate to read_vrt which handles chunks if source.lower().endswith('.vrt'): - return read_vrt(source, name=name, chunks=chunks) + return read_vrt(source, dtype=dtype, name=name, chunks=chunks) # First, do a metadata-only read to get shape, dtype, coords, attrs arr, geo_info = read_to_array(source, overview_level=overview_level) full_h, full_w = arr.shape[:2] n_bands = arr.shape[2] if arr.ndim == 3 else 0 - dtype = arr.dtype + file_dtype = arr.dtype + nodata = geo_info.nodata + + # Nodata masking promotes integer arrays to float64 (for NaN). 
+ # Validate against the effective dtype, not the raw file dtype. + if nodata is not None and file_dtype.kind in ('u', 'i'): + effective_dtype = np.dtype('float64') + else: + effective_dtype = file_dtype + + if dtype is not None: + target_dtype = np.dtype(dtype) + _validate_dtype_cast(effective_dtype, target_dtype) + else: + target_dtype = effective_dtype coords = _geo_to_coords(geo_info, full_h, full_w) @@ -565,8 +853,8 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, attrs['crs'] = geo_info.crs_epsg if geo_info.raster_type == RASTER_PIXEL_IS_POINT: attrs['raster_type'] = 'point' - if geo_info.nodata is not None: - attrs['nodata'] = geo_info.nodata + if nodata is not None: + attrs['nodata'] = nodata if isinstance(chunks, int): ch_h = ch_w = chunks @@ -593,10 +881,11 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, block_shape = (r1 - r0, c1 - c0) block = da.from_delayed( _delayed_read_window(source, r0, c0, r1, c1, - overview_level, geo_info.nodata, - dtype, band_arg), + overview_level, nodata, + band_arg, + target_dtype=target_dtype if dtype is not None else None), shape=block_shape, - dtype=dtype, + dtype=target_dtype, ) dask_cols.append(block) dask_rows.append(da.concatenate(dask_cols, axis=1)) @@ -615,7 +904,7 @@ def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, def _delayed_read_window(source, r0, c0, r1, c1, overview_level, nodata, - dtype, band): + band, *, target_dtype=None): """Dask-delayed function to read a single window.""" import dask @dask.delayed @@ -631,11 +920,14 @@ def _read(): if mask.any(): arr = arr.astype(np.float64) arr[mask] = np.nan + if target_dtype is not None: + arr = arr.astype(target_dtype) return arr return _read() def read_geotiff_gpu(source: str, *, + dtype=None, overview_level: int | None = None, name: str | None = None, chunks: int | tuple | None = None) -> xr.DataArray: @@ -699,7 +991,7 @@ def read_geotiff_gpu(source: str, *, bps = ifd.bits_per_sample if isinstance(bps, 
tuple): bps = bps[0] - dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) + file_dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) geo_info = extract_geo_info(ifd, data, header.byte_order) if not ifd.is_tiled: @@ -714,6 +1006,10 @@ def read_geotiff_gpu(source: str, *, attrs = {} if geo_info.crs_epsg is not None: attrs['crs'] = geo_info.crs_epsg + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target) + arr_gpu = arr_gpu.astype(target) return xr.DataArray(arr_gpu, dims=['y', 'x'], coords=coords, name=name, attrs=attrs) @@ -738,7 +1034,7 @@ def read_geotiff_gpu(source: str, *, arr_gpu = gpu_decode_tiles_from_file( source, offsets, byte_counts, tw, th, width, height, - compression, predictor, dtype, samples, + compression, predictor, file_dtype, samples, ) except Exception: pass @@ -760,13 +1056,18 @@ def read_geotiff_gpu(source: str, *, arr_gpu = gpu_decode_tiles( compressed_tiles, tw, th, width, height, - compression, predictor, dtype, samples, + compression, predictor, file_dtype, samples, ) except (ValueError, Exception): # Unsupported compression -- fall back to CPU then transfer arr_cpu, _ = read_to_array(source, overview_level=overview_level) arr_gpu = cupy.asarray(arr_cpu) + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(np.dtype(str(arr_gpu.dtype)), target) + arr_gpu = arr_gpu.astype(target) + # Build DataArray if name is None: import os @@ -803,6 +1104,7 @@ def write_geotiff_gpu(data, path: str, *, crs: int | str | None = None, nodata=None, compression: str = 'zstd', + compression_level: int | None = None, tile_size: int = 256, predictor: bool = False) -> None: """Write a CuPy-backed DataArray as a GeoTIFF with GPU compression. @@ -827,6 +1129,9 @@ def write_geotiff_gpu(data, path: str, *, compression : str 'zstd' (default, fastest on GPU), 'deflate', 'jpeg', or 'none'. JPEG uses nvJPEG when available, falling back to Pillow. 
+ compression_level : int or None + Compression effort level. Accepted for API compatibility but + currently ignored -- nvCOMP does not expose level control. tile_size : int Tile size in pixels (default 256). predictor : bool @@ -919,7 +1224,7 @@ def write_geotiff_gpu(data, path: str, *, _write_bytes(file_bytes, path) -def read_vrt(source: str, *, window=None, +def read_vrt(source: str, *, dtype=None, window=None, band: int | None = None, name: str | None = None, chunks: int | tuple | None = None, @@ -933,6 +1238,9 @@ def read_vrt(source: str, *, window=None, ---------- source : str Path to the .vrt file. + dtype : str, numpy.dtype, or None + Cast the result to this dtype after reading. None keeps the + file's native dtype. Float-to-int casts raise ValueError. window : tuple or None (row_start, col_start, row_stop, col_stop) for windowed reading. band : int or None @@ -991,6 +1299,11 @@ def read_vrt(source: str, *, window=None, import cupy arr = cupy.asarray(arr) + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(np.dtype(str(arr.dtype)), target) + arr = arr.astype(target) + if arr.ndim == 3: dims = ['y', 'x', 'band'] coords['band'] = np.arange(arr.shape[2]) diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py index 2f41e1a5..dffe2f0f 100644 --- a/xrspatial/geotiff/_compression.py +++ b/xrspatial/geotiff/_compression.py @@ -997,7 +997,8 @@ def compress(data: bytes, compression: int, level: int = 6) -> bytes: compression : int TIFF compression tag value. level : int - Compression level (for deflate). + Compression level (deflate: 1-9, zstd: 1-22, lz4: 0-16). + Ignored for codecs without level support. 
Returns ------- diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 8a6f2671..ca817f14 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -456,6 +456,8 @@ def write_vrt(vrt_path: str, source_files: list[str], *, if relative: try: fname = os.path.relpath(fname, vrt_dir) + # VRT XML uses forward slashes regardless of platform + fname = fname.replace('\\', '/') rel_attr = '1' except ValueError: pass # different drives on Windows diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py index eabe8695..c0fa5133 100644 --- a/xrspatial/geotiff/_writer.py +++ b/xrspatial/geotiff/_writer.py @@ -296,7 +296,8 @@ def _build_ifd(tags: list[tuple], overflow_base: int, # --------------------------------------------------------------------------- def _write_stripped(data: np.ndarray, compression: int, predictor: bool, - rows_per_strip: int = 256) -> tuple[list, list, list]: + rows_per_strip: int = 256, + compression_level: int | None = None) -> tuple[list, list, list]: """Compress data as strips. 
Returns @@ -329,7 +330,10 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, buf = strip_arr.view(np.uint8).ravel().copy() buf = predictor_encode(buf, width, strip_rows, bytes_per_sample * samples) strip_data = buf.tobytes() - compressed = compress(strip_data, compression) + if compression_level is None: + compressed = compress(strip_data, compression) + else: + compressed = compress(strip_data, compression, level=compression_level) else: strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() @@ -341,8 +345,10 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, from ._compression import lerc_compress compressed = lerc_compress( strip_data, width, strip_rows, samples=samples, dtype=dtype) - else: + elif compression_level is None: compressed = compress(strip_data, compression) + else: + compressed = compress(strip_data, compression, level=compression_level) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) @@ -357,7 +363,8 @@ def _write_stripped(data: np.ndarray, compression: int, predictor: bool, # --------------------------------------------------------------------------- def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype, - bytes_per_sample, predictor, compression): + bytes_per_sample, predictor, compression, + compression_level=None): """Extract, pad, and compress a single tile. 
Thread-safe.""" r0 = tr * th c0 = tc * tw @@ -400,11 +407,14 @@ def _prepare_tile(data, tr, tc, th, tw, height, width, samples, dtype, from ._compression import lerc_compress return lerc_compress( tile_data, tw, th, samples=samples, dtype=dtype) - return compress(tile_data, compression) + if compression_level is None: + return compress(tile_data, compression) + return compress(tile_data, compression, level=compression_level) def _write_tiled(data: np.ndarray, compression: int, predictor: bool, - tile_size: int = 256) -> tuple[list, list, list]: + tile_size: int = 256, + compression_level: int | None = None) -> tuple[list, list, list]: """Compress data as tiles, using parallel compression. For compressed formats (deflate, lzw, zstd), tiles are compressed @@ -477,6 +487,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool, compressed = _prepare_tile( data, tr, tc, th, tw, height, width, samples, dtype, bytes_per_sample, predictor, compression, + compression_level, ) rel_offsets.append(current_offset) byte_counts.append(len(compressed)) @@ -497,6 +508,7 @@ def _write_tiled(data: np.ndarray, compression: int, predictor: bool, pool.submit( _prepare_tile, data, tr, tc, th, tw, height, width, samples, dtype, bytes_per_sample, predictor, compression, + compression_level, ) for tr, tc in tile_indices ] @@ -855,6 +867,7 @@ def write(data: np.ndarray, path: str, *, crs_wkt: str | None = None, nodata=None, compression: str = 'zstd', + compression_level: int | None = None, tiled: bool = True, tile_size: int = 256, predictor: bool = False, @@ -914,9 +927,11 @@ def write(data: np.ndarray, path: str, *, # Full resolution if tiled: - rel_off, bc, comp_data = _write_tiled(data, comp_tag, predictor, tile_size) + rel_off, bc, comp_data = _write_tiled(data, comp_tag, predictor, tile_size, + compression_level=compression_level) else: - rel_off, bc, comp_data = _write_stripped(data, comp_tag, predictor) + rel_off, bc, comp_data = _write_stripped(data, comp_tag, 
predictor, + compression_level=compression_level) h, w = data.shape[:2] parts.append((data, w, h, rel_off, bc, comp_data)) @@ -938,9 +953,12 @@ def write(data: np.ndarray, path: str, *, current = _make_overview(current, method=overview_resampling) oh, ow = current.shape[:2] if tiled: - o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor, tile_size) + o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor, + tile_size, + compression_level=compression_level) else: - o_off, o_bc, o_data = _write_stripped(current, comp_tag, predictor) + o_off, o_bc, o_data = _write_stripped(current, comp_tag, predictor, + compression_level=compression_level) parts.append((current, ow, oh, o_off, o_bc, o_data)) file_bytes = _assemble_tiff( diff --git a/xrspatial/geotiff/tests/test_compression_level.py b/xrspatial/geotiff/tests/test_compression_level.py new file mode 100644 index 00000000..f778a92c --- /dev/null +++ b/xrspatial/geotiff/tests/test_compression_level.py @@ -0,0 +1,171 @@ +"""Tests for compression_level parameter in to_geotiff / write.""" +from __future__ import annotations + +import os + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_da(seed: int = 0, shape: tuple = (64, 64)) -> xr.DataArray: + """Return a small float32 DataArray with reproducible content.""" + rng = np.random.default_rng(seed) + arr = rng.standard_normal(shape).astype(np.float32) + return xr.DataArray(arr, dims=['y', 'x']) + + +def _write_read(da, tmp_path, **kwargs) -> tuple[xr.DataArray, int]: + """Write *da* to *tmp_path* with **kwargs, read back, return (da, file_size).""" + to_geotiff(da, tmp_path, **kwargs) + result = open_geotiff(tmp_path) + size = os.path.getsize(tmp_path) + return result, size + + +# 
--------------------------------------------------------------------------- +# Round-trip correctness +# --------------------------------------------------------------------------- + +class TestRoundTrip: + """Data survives write/read with various compression levels.""" + + def test_zstd_level_1(self, tmp_path): + da = _make_da(seed=1) + path = str(tmp_path / 'zstd1.tif') + result, _ = _write_read(da, path, compression='zstd', compression_level=1) + np.testing.assert_allclose(result.values, da.values) + + def test_zstd_level_22(self, tmp_path): + da = _make_da(seed=2) + path = str(tmp_path / 'zstd22.tif') + result, _ = _write_read(da, path, compression='zstd', compression_level=22) + np.testing.assert_allclose(result.values, da.values) + + def test_deflate_level_1(self, tmp_path): + da = _make_da(seed=3) + path = str(tmp_path / 'deflate1.tif') + result, _ = _write_read(da, path, compression='deflate', compression_level=1) + np.testing.assert_allclose(result.values, da.values) + + def test_deflate_level_9(self, tmp_path): + da = _make_da(seed=4) + path = str(tmp_path / 'deflate9.tif') + result, _ = _write_read(da, path, compression='deflate', compression_level=9) + np.testing.assert_allclose(result.values, da.values) + + +# --------------------------------------------------------------------------- +# Higher level → smaller file +# --------------------------------------------------------------------------- + +class TestLevelEffect: + """Higher compression level produces a smaller or equal file.""" + + def _make_compressible(self, shape=(128, 128)): + """Smooth, highly compressible float32 array.""" + rng = np.random.default_rng(42) + # Smooth gradient + small noise -- compresses well + y, x = np.mgrid[0:shape[0], 0:shape[1]] + arr = (y + x).astype(np.float32) + rng.standard_normal(shape).astype(np.float32) * 0.01 + return xr.DataArray(arr, dims=['y', 'x']) + + def test_zstd_higher_level_not_larger(self, tmp_path): + da = self._make_compressible() + path_lo = 
str(tmp_path / 'zstd_lo.tif') + path_hi = str(tmp_path / 'zstd_hi.tif') + to_geotiff(da, path_lo, compression='zstd', compression_level=1) + to_geotiff(da, path_hi, compression='zstd', compression_level=22) + size_lo = os.path.getsize(path_lo) + size_hi = os.path.getsize(path_hi) + assert size_hi <= size_lo, ( + f"Expected level-22 file ({size_hi}) <= level-1 file ({size_lo})") + + def test_deflate_higher_level_not_larger(self, tmp_path): + da = self._make_compressible() + path_lo = str(tmp_path / 'deflate_lo.tif') + path_hi = str(tmp_path / 'deflate_hi.tif') + to_geotiff(da, path_lo, compression='deflate', compression_level=1) + to_geotiff(da, path_hi, compression='deflate', compression_level=9) + size_lo = os.path.getsize(path_lo) + size_hi = os.path.getsize(path_hi) + assert size_hi <= size_lo, ( + f"Expected level-9 file ({size_hi}) <= level-1 file ({size_lo})") + + +# --------------------------------------------------------------------------- +# Default (level=None) uses codec default +# --------------------------------------------------------------------------- + +class TestDefaultLevel: + def test_none_uses_default_zstd(self, tmp_path): + da = _make_da(seed=5) + path = str(tmp_path / 'zstd_default.tif') + # Should not raise and should produce a valid file + result, size = _write_read(da, path, compression='zstd', compression_level=None) + assert size > 0 + np.testing.assert_allclose(result.values, da.values) + + def test_omitted_uses_default_deflate(self, tmp_path): + da = _make_da(seed=6) + path = str(tmp_path / 'deflate_default.tif') + # compression_level not passed at all -- should use codec default + result, size = _write_read(da, path, compression='deflate') + assert size > 0 + np.testing.assert_allclose(result.values, da.values) + + +# --------------------------------------------------------------------------- +# LZW ignores level silently +# --------------------------------------------------------------------------- + +class TestLZWIgnoresLevel: + def 
test_lzw_with_level_does_not_raise(self, tmp_path): + da = _make_da(seed=7) + path = str(tmp_path / 'lzw_level.tif') + # LZW has no level concept; passing one should succeed silently + result, size = _write_read(da, path, compression='lzw', compression_level=5) + assert size > 0 + np.testing.assert_allclose(result.values, da.values) + + +# --------------------------------------------------------------------------- +# Invalid levels raise ValueError +# --------------------------------------------------------------------------- + +class TestInvalidLevels: + def test_zstd_level_0_raises(self, tmp_path): + da = _make_da() + path = str(tmp_path / 'bad.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(da, path, compression='zstd', compression_level=0) + + def test_zstd_level_23_raises(self, tmp_path): + da = _make_da() + path = str(tmp_path / 'bad.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(da, path, compression='zstd', compression_level=23) + + def test_deflate_level_0_raises(self, tmp_path): + da = _make_da() + path = str(tmp_path / 'bad.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(da, path, compression='deflate', compression_level=0) + + def test_deflate_level_10_raises(self, tmp_path): + da = _make_da() + path = str(tmp_path / 'bad.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(da, path, compression='deflate', compression_level=10) + + def test_negative_level_raises(self, tmp_path): + da = _make_da() + path = str(tmp_path / 'bad.tif') + with pytest.raises(ValueError, match='compression_level'): + to_geotiff(da, path, compression='zstd', compression_level=-1) diff --git a/xrspatial/geotiff/tests/test_dtype_read.py b/xrspatial/geotiff/tests/test_dtype_read.py new file mode 100644 index 00000000..538aa882 --- /dev/null +++ b/xrspatial/geotiff/tests/test_dtype_read.py @@ -0,0 +1,117 @@ +"""Tests for dtype parameter on open_geotiff.""" 
+import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +@pytest.fixture +def float64_tif(tmp_path): + """Write a float64 GeoTIFF for dtype cast tests.""" + arr = np.random.default_rng(99).random((80, 80)).astype(np.float64) + y = np.linspace(40.0, 41.0, 80) + x = np.linspace(-105.0, -104.0, 80) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + path = str(tmp_path / 'test_1083_f64.tif') + to_geotiff(da, path, compression='none') + return path, arr + + +@pytest.fixture +def uint16_tif(tmp_path): + """Write a uint16 GeoTIFF for dtype cast tests.""" + arr = np.random.default_rng(77).integers(0, 10000, (60, 60), + dtype=np.uint16) + y = np.linspace(40.0, 41.0, 60) + x = np.linspace(-105.0, -104.0, 60) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + path = str(tmp_path / 'test_1083_u16.tif') + to_geotiff(da, path, compression='none') + return path, arr + + +class TestDtypeEager: + def test_float64_to_float32(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype='float32') + assert result.dtype == np.float32 + np.testing.assert_array_almost_equal( + result.values, orig.astype(np.float32), decimal=6) + + def test_float64_to_float16(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype=np.float16) + assert result.dtype == np.float16 + + def test_uint16_to_int32(self, uint16_tif): + path, orig = uint16_tif + result = open_geotiff(path, dtype='int32') + assert result.dtype == np.int32 + np.testing.assert_array_equal(result.values, orig.astype(np.int32)) + + def test_uint16_to_uint8(self, uint16_tif): + path, _ = uint16_tif + result = open_geotiff(path, dtype='uint8') + assert result.dtype == np.uint8 + + def test_float_to_int_raises(self, float64_tif): + path, _ = float64_tif + with pytest.raises(ValueError, match='float.*int'): + open_geotiff(path, dtype='int32') + + 
def test_dtype_none_preserves_native(self, float64_tif): + path, _ = float64_tif + result = open_geotiff(path, dtype=None) + assert result.dtype == np.float64 + + + def test_int_with_nodata_float_to_int_raises(self, tmp_path): + """uint16 file with nodata: nodata masking promotes to float64, so float->int validation fires.""" + arr = np.array([[1, 2], [3, 9999]], dtype=np.uint16) + y = np.linspace(40.0, 41.0, 2) + x = np.linspace(-105.0, -104.0, 2) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326, 'nodata': 9999.0}) + path = str(tmp_path / 'test_1083_nodata_int_eager.tif') + to_geotiff(da, path, compression='none') + with pytest.raises(ValueError, match='float.*int'): + open_geotiff(path, dtype='int32') + + +class TestDtypeDask: + def test_float64_to_float32_dask(self, float64_tif): + path, orig = float64_tif + result = open_geotiff(path, dtype='float32', chunks=40) + assert result.dtype == np.float32 + computed = result.values + np.testing.assert_array_almost_equal( + computed, orig.astype(np.float32), decimal=6) + + def test_chunks_are_target_dtype(self, float64_tif): + path, _ = float64_tif + result = open_geotiff(path, dtype='float32', chunks=40) + assert result.data.dtype == np.float32 + + def test_float_to_int_raises_dask(self, float64_tif): + path, _ = float64_tif + with pytest.raises(ValueError, match='float.*int'): + open_geotiff(path, dtype='int32', chunks=40) + + def test_int_with_nodata_float_to_int_raises_dask(self, tmp_path): + """uint16 file with nodata: nodata masking promotes to float64, so float->int validation fires.""" + arr = np.array([[1, 2], [3, 9999]], dtype=np.uint16) + y = np.linspace(40.0, 41.0, 2) + x = np.linspace(-105.0, -104.0, 2) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326, 'nodata': 9999.0}) + path = str(tmp_path / 'test_1083_nodata_int_dask.tif') + to_geotiff(da, path, compression='none') + with pytest.raises(ValueError, match='float.*int'): + 
open_geotiff(path, dtype='int32', chunks=2) diff --git a/xrspatial/geotiff/tests/test_vrt_write.py b/xrspatial/geotiff/tests/test_vrt_write.py new file mode 100644 index 00000000..ab1b4e2c --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_write.py @@ -0,0 +1,111 @@ +"""Tests for VRT tiled output from to_geotiff.""" +import numpy as np +import os +import pytest +import xarray as xr + +from xrspatial.geotiff import open_geotiff, to_geotiff + + +@pytest.fixture +def sample_raster(): + """200x200 float32 raster with coords and CRS.""" + arr = np.random.default_rng(55).random((200, 200), dtype=np.float32) + y = np.linspace(41.0, 40.0, 200) # north-to-south + x = np.linspace(-106.0, -105.0, 200) + da = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326, 'nodata': -9999.0}) + return da + + +class TestVrtOutputNumpy: + def test_creates_vrt_and_tiles_dir(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'out_1083.vrt') + to_geotiff(sample_raster, vrt_path) + assert os.path.exists(vrt_path) + tiles_dir = str(tmp_path / 'out_1083_tiles') + assert os.path.isdir(tiles_dir) + tile_files = os.listdir(tiles_dir) + assert len(tile_files) > 0 + assert all(f.endswith('.tif') for f in tile_files) + + def test_round_trip_numpy(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'rt_1083.vrt') + to_geotiff(sample_raster, vrt_path) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_tile_naming_convention(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'named_1083.vrt') + to_geotiff(sample_raster, vrt_path, tile_size=100) + tiles_dir = str(tmp_path / 'named_1083_tiles') + files = sorted(os.listdir(tiles_dir)) + # 200x200 with tile_size=100 -> 2x2 grid + assert files == [ + 'tile_00_00.tif', 'tile_00_01.tif', + 'tile_01_00.tif', 'tile_01_01.tif', + ] + + def test_relative_paths_in_vrt(self, sample_raster, tmp_path): + vrt_path = str(tmp_path 
/ 'rel_1083.vrt') + to_geotiff(sample_raster, vrt_path) + with open(vrt_path) as f: + content = f.read() + # Paths should be relative (no leading /) + assert 'rel_1083_tiles/' in content + assert str(tmp_path) not in content + + def test_compression_level_passed_to_tiles(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'cl_1083.vrt') + to_geotiff(sample_raster, vrt_path, compression='zstd', + compression_level=1) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + +class TestVrtOutputDask: + def test_dask_round_trip(self, sample_raster, tmp_path): + dask_da = sample_raster.chunk({'y': 100, 'x': 100}) + vrt_path = str(tmp_path / 'dask_1083.vrt') + to_geotiff(dask_da, vrt_path) + result = open_geotiff(vrt_path) + np.testing.assert_array_almost_equal( + result.values, sample_raster.values, decimal=5) + + def test_dask_one_tile_per_chunk(self, sample_raster, tmp_path): + dask_da = sample_raster.chunk({'y': 100, 'x': 100}) + vrt_path = str(tmp_path / 'chunks_1083.vrt') + to_geotiff(dask_da, vrt_path) + tiles_dir = str(tmp_path / 'chunks_1083_tiles') + # 200x200 chunked 100x100 -> 2x2 = 4 tiles + assert len(os.listdir(tiles_dir)) == 4 + + +class TestVrtEdgeCases: + def test_cog_with_vrt_raises(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'cog_1083.vrt') + with pytest.raises(ValueError, match='cog.*vrt|vrt.*cog|COG.*VRT|VRT.*COG|cog.*VRT|vrt.*COG'): + to_geotiff(sample_raster, vrt_path, cog=True) + + def test_overview_levels_with_vrt_raises(self, sample_raster, tmp_path): + vrt_path = str(tmp_path / 'ovr_1083.vrt') + with pytest.raises(ValueError, match='overview.*vrt|vrt.*overview|overview.*VRT|VRT.*overview'): + to_geotiff(sample_raster, vrt_path, overview_levels=[2, 4]) + + def test_nonempty_tiles_dir_raises(self, sample_raster, tmp_path): + tiles_dir = tmp_path / 'exist_1083_tiles' + tiles_dir.mkdir() + (tiles_dir / 'dummy.tif').write_text('x') + vrt_path = 
str(tmp_path / 'exist_1083.vrt') + with pytest.raises(FileExistsError): + to_geotiff(sample_raster, vrt_path) + + def test_empty_tiles_dir_ok(self, sample_raster, tmp_path): + tiles_dir = tmp_path / 'empty_1083_tiles' + tiles_dir.mkdir() + vrt_path = str(tmp_path / 'empty_1083.vrt') + to_geotiff(sample_raster, vrt_path) + assert os.path.exists(vrt_path)