|
| 1 | +"""Datum shift grid loading and interpolation. |
| 2 | +
|
| 3 | +Downloads horizontal offset grids from the PROJ CDN, caches them locally, |
| 4 | +and provides Numba JIT bilinear interpolation for per-pixel datum shifts. |
| 5 | +
|
| 6 | +Grid format: GeoTIFF with 2+ bands: |
| 7 | + Band 1: latitude offset (arc-seconds) |
| 8 | + Band 2: longitude offset (arc-seconds) |
| 9 | +""" |
| 10 | +from __future__ import annotations |
| 11 | + |
| 12 | +import math |
| 13 | +import os |
| 14 | +import urllib.request |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +from numba import njit, prange |
| 18 | + |
| 19 | +_PROJ_CDN = "https://cdn.proj.org" |
| 20 | + |
| 21 | +# Vendored grid directory (shipped with the package) |
| 22 | +_VENDORED_DIR = os.path.join(os.path.dirname(__file__), 'grids') |
| 23 | + |
| 24 | +# Grid registry: key -> (filename, coverage bounds, description, cdn_url) |
| 25 | +# Bounds are (lon_min, lat_min, lon_max, lat_max). |
| 26 | +GRID_REGISTRY = { |
| 27 | + 'NAD27_CONUS': ( |
| 28 | + 'us_noaa_conus.tif', |
| 29 | + (-131, 20, -63, 50), |
| 30 | + 'NAD27->NAD83 CONUS (NADCON)', |
| 31 | + f'{_PROJ_CDN}/us_noaa_conus.tif', |
| 32 | + ), |
| 33 | + 'NAD27_NADCON5_CONUS': ( |
| 34 | + 'us_noaa_nadcon5_nad27_nad83_1986_conus.tif', |
| 35 | + (-125, 24, -66, 50), |
| 36 | + 'NAD27->NAD83 CONUS (NADCON5)', |
| 37 | + f'{_PROJ_CDN}/us_noaa_nadcon5_nad27_nad83_1986_conus.tif', |
| 38 | + ), |
| 39 | +} |
| 40 | + |
| 41 | +# Cache directory for grids not vendored |
| 42 | +_CACHE_DIR = os.path.join(os.path.expanduser('~'), '.cache', 'xrspatial', 'proj_grids') |
| 43 | + |
| 44 | + |
| 45 | +def _ensure_cache_dir(): |
| 46 | + os.makedirs(_CACHE_DIR, exist_ok=True) |
| 47 | + |
| 48 | + |
| 49 | +def _find_grid_file(filename, cdn_url=None): |
| 50 | + """Find a grid file: check vendored dir first, then cache, then download.""" |
| 51 | + # 1. Vendored (shipped with package) |
| 52 | + vendored = os.path.join(_VENDORED_DIR, filename) |
| 53 | + if os.path.exists(vendored): |
| 54 | + return vendored |
| 55 | + |
| 56 | + # 2. User cache |
| 57 | + cached = os.path.join(_CACHE_DIR, filename) |
| 58 | + if os.path.exists(cached): |
| 59 | + return cached |
| 60 | + |
| 61 | + # 3. Download from CDN |
| 62 | + if cdn_url: |
| 63 | + _ensure_cache_dir() |
| 64 | + urllib.request.urlretrieve(cdn_url, cached) |
| 65 | + return cached |
| 66 | + |
| 67 | + return None |
| 68 | + |
| 69 | + |
| 70 | +def load_grid(grid_key): |
| 71 | + """Load a datum shift grid by registry key. |
| 72 | +
|
| 73 | + Returns (dlat, dlon, bounds, resolution) where: |
| 74 | + - dlat, dlon: numpy float64 arrays (arc-seconds), shape (H, W) |
| 75 | + - bounds: (left, bottom, right, top) in degrees |
| 76 | + - resolution: (res_lon, res_lat) in degrees |
| 77 | + """ |
| 78 | + if grid_key not in GRID_REGISTRY: |
| 79 | + return None |
| 80 | + |
| 81 | + filename, _, _, cdn_url = GRID_REGISTRY[grid_key] |
| 82 | + path = _find_grid_file(filename, cdn_url) |
| 83 | + if path is None: |
| 84 | + return None |
| 85 | + |
| 86 | + # Read with rasterio for correct multi-band handling |
| 87 | + try: |
| 88 | + import rasterio |
| 89 | + with rasterio.open(path) as ds: |
| 90 | + dlat = ds.read(1).astype(np.float64) # arc-seconds |
| 91 | + dlon = ds.read(2).astype(np.float64) # arc-seconds |
| 92 | + b = ds.bounds |
| 93 | + bounds = (b.left, b.bottom, b.right, b.top) |
| 94 | + res = ds.res # (res_y, res_x) in degrees |
| 95 | + return dlat, dlon, bounds, (res[1], res[0]) |
| 96 | + except ImportError: |
| 97 | + pass |
| 98 | + |
| 99 | + # Fallback: read with our own reader (may need band axis handling) |
| 100 | + from xrspatial.geotiff import read_geotiff |
| 101 | + da = read_geotiff(path) |
| 102 | + data = da.values |
| 103 | + if data.ndim == 3: |
| 104 | + # (H, W, bands) or (bands, H, W) |
| 105 | + if data.shape[2] == 2: |
| 106 | + dlat = data[:, :, 0].astype(np.float64) |
| 107 | + dlon = data[:, :, 1].astype(np.float64) |
| 108 | + else: |
| 109 | + dlat = data[0].astype(np.float64) |
| 110 | + dlon = data[1].astype(np.float64) |
| 111 | + else: |
| 112 | + return None |
| 113 | + |
| 114 | + y_coords = da.coords['y'].values |
| 115 | + x_coords = da.coords['x'].values |
| 116 | + bounds = (float(x_coords[0]), float(y_coords[-1]), |
| 117 | + float(x_coords[-1]), float(y_coords[0])) |
| 118 | + res_x = abs(float(x_coords[1] - x_coords[0])) if len(x_coords) > 1 else 0.25 |
| 119 | + res_y = abs(float(y_coords[1] - y_coords[0])) if len(y_coords) > 1 else 0.25 |
| 120 | + return dlat, dlon, bounds, (res_x, res_y) |
| 121 | + |
| 122 | + |
| 123 | +# --------------------------------------------------------------------------- |
| 124 | +# Numba bilinear grid interpolation |
| 125 | +# --------------------------------------------------------------------------- |
| 126 | + |
| 127 | +@njit(nogil=True, cache=True) |
| 128 | +def _grid_interp_point(lon, lat, dlat_grid, dlon_grid, |
| 129 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 130 | + grid_h, grid_w): |
| 131 | + """Bilinear interpolation of a single point in the shift grid. |
| 132 | +
|
| 133 | + Returns (dlat_arcsec, dlon_arcsec) or (0, 0) if outside the grid. |
| 134 | + """ |
| 135 | + col_f = (lon - grid_left) / grid_res_x |
| 136 | + row_f = (grid_top - lat) / grid_res_y |
| 137 | + |
| 138 | + if col_f < 0 or col_f > grid_w - 1 or row_f < 0 or row_f > grid_h - 1: |
| 139 | + return 0.0, 0.0 |
| 140 | + |
| 141 | + c0 = int(col_f) |
| 142 | + r0 = int(row_f) |
| 143 | + if c0 >= grid_w - 1: |
| 144 | + c0 = grid_w - 2 |
| 145 | + if r0 >= grid_h - 1: |
| 146 | + r0 = grid_h - 2 |
| 147 | + |
| 148 | + dc = col_f - c0 |
| 149 | + dr = row_f - r0 |
| 150 | + |
| 151 | + w00 = (1.0 - dr) * (1.0 - dc) |
| 152 | + w01 = (1.0 - dr) * dc |
| 153 | + w10 = dr * (1.0 - dc) |
| 154 | + w11 = dr * dc |
| 155 | + |
| 156 | + dlat = (dlat_grid[r0, c0] * w00 + dlat_grid[r0, c0 + 1] * w01 + |
| 157 | + dlat_grid[r0 + 1, c0] * w10 + dlat_grid[r0 + 1, c0 + 1] * w11) |
| 158 | + dlon = (dlon_grid[r0, c0] * w00 + dlon_grid[r0, c0 + 1] * w01 + |
| 159 | + dlon_grid[r0 + 1, c0] * w10 + dlon_grid[r0 + 1, c0 + 1] * w11) |
| 160 | + |
| 161 | + return dlat, dlon |
| 162 | + |
| 163 | + |
| 164 | +@njit(nogil=True, cache=True, parallel=True) |
| 165 | +def apply_grid_shift_forward(lon_arr, lat_arr, dlat_grid, dlon_grid, |
| 166 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 167 | + grid_h, grid_w): |
| 168 | + """Apply grid-based datum shift: source -> target (add offsets).""" |
| 169 | + for i in prange(lon_arr.shape[0]): |
| 170 | + dlat, dlon = _grid_interp_point( |
| 171 | + lon_arr[i], lat_arr[i], dlat_grid, dlon_grid, |
| 172 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 173 | + grid_h, grid_w, |
| 174 | + ) |
| 175 | + lat_arr[i] += dlat / 3600.0 # arc-seconds to degrees |
| 176 | + lon_arr[i] += dlon / 3600.0 |
| 177 | + |
| 178 | + |
| 179 | +@njit(nogil=True, cache=True, parallel=True) |
| 180 | +def apply_grid_shift_inverse(lon_arr, lat_arr, dlat_grid, dlon_grid, |
| 181 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 182 | + grid_h, grid_w): |
| 183 | + """Apply inverse grid-based datum shift: target -> source (subtract offsets). |
| 184 | +
|
| 185 | + Uses iterative approach: the grid is indexed by source coordinates, |
| 186 | + but we have target coordinates. One iteration is usually sufficient |
| 187 | + since the shifts are small relative to the grid spacing. |
| 188 | + """ |
| 189 | + for i in prange(lon_arr.shape[0]): |
| 190 | + # Initial estimate: subtract the shift at the target coords |
| 191 | + dlat, dlon = _grid_interp_point( |
| 192 | + lon_arr[i], lat_arr[i], dlat_grid, dlon_grid, |
| 193 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 194 | + grid_h, grid_w, |
| 195 | + ) |
| 196 | + lon_est = lon_arr[i] - dlon / 3600.0 |
| 197 | + lat_est = lat_arr[i] - dlat / 3600.0 |
| 198 | + |
| 199 | + # Refine: re-interpolate at the estimated source coords |
| 200 | + dlat2, dlon2 = _grid_interp_point( |
| 201 | + lon_est, lat_est, dlat_grid, dlon_grid, |
| 202 | + grid_left, grid_top, grid_res_x, grid_res_y, |
| 203 | + grid_h, grid_w, |
| 204 | + ) |
| 205 | + lon_arr[i] -= dlon2 / 3600.0 |
| 206 | + lat_arr[i] -= dlat2 / 3600.0 |
| 207 | + |
| 208 | + |
| 209 | +# --------------------------------------------------------------------------- |
| 210 | +# Grid cache (loaded grids, keyed by grid_key) |
| 211 | +# --------------------------------------------------------------------------- |
| 212 | + |
| 213 | +_loaded_grids = {} |
| 214 | + |
| 215 | + |
| 216 | +def get_grid(grid_key): |
| 217 | + """Get a loaded grid, downloading if necessary. |
| 218 | +
|
| 219 | + Returns (dlat, dlon, left, top, res_x, res_y, h, w) or None. |
| 220 | + """ |
| 221 | + if grid_key in _loaded_grids: |
| 222 | + return _loaded_grids[grid_key] |
| 223 | + |
| 224 | + result = load_grid(grid_key) |
| 225 | + if result is None: |
| 226 | + _loaded_grids[grid_key] = None |
| 227 | + return None |
| 228 | + |
| 229 | + dlat, dlon, bounds, (res_x, res_y) = result |
| 230 | + h, w = dlat.shape |
| 231 | + # Ensure contiguous float64 for Numba |
| 232 | + dlat = np.ascontiguousarray(dlat, dtype=np.float64) |
| 233 | + dlon = np.ascontiguousarray(dlon, dtype=np.float64) |
| 234 | + entry = (dlat, dlon, bounds[0], bounds[3], res_x, res_y, h, w) |
| 235 | + _loaded_grids[grid_key] = entry |
| 236 | + return entry |
| 237 | + |
| 238 | + |
| 239 | +def find_grid_for_point(lon, lat, datum_key): |
| 240 | + """Find the best grid covering a given point. |
| 241 | +
|
| 242 | + Returns the grid_key or None. |
| 243 | + """ |
| 244 | + # Map datum names to grid keys, ordered by preference |
| 245 | + datum_grids = { |
| 246 | + 'NAD27': ['NAD27_NADCON5_CONUS', 'NAD27_CONUS'], |
| 247 | + 'clarke66': ['NAD27_NADCON5_CONUS', 'NAD27_CONUS'], |
| 248 | + } |
| 249 | + |
| 250 | + candidates = datum_grids.get(datum_key, []) |
| 251 | + for grid_key in candidates: |
| 252 | + entry = GRID_REGISTRY.get(grid_key) |
| 253 | + if entry is None: |
| 254 | + continue |
| 255 | + _, coverage, _, _ = entry |
| 256 | + lon_min, lat_min, lon_max, lat_max = coverage |
| 257 | + if lon_min <= lon <= lon_max and lat_min <= lat <= lat_max: |
| 258 | + return grid_key |
| 259 | + return None |
0 commit comments