Skip to content

Commit 2d8e18f

Browse files
committed
Fixes #903: add comprehensive input validation across public API
Add _validate_raster() and _validate_scalar() utilities to utils.py and call them at the top of every public function (~35 entry points across 15 modules). Invalid input now raises a clear TypeError/ValueError with the function name, parameter name, and expected-vs-actual values instead of a cryptic numba TypingError deep in JIT kernels.

- _validate_raster: checks DataArray type, ndim, numeric/integer dtype
- _validate_scalar: checks scalar type (incl. numpy equivalents), range
- Replace ad-hoc isinstance/ndim checks in focal, zonal, cost_distance
- Add validation to all surface, proximity, classify, multispectral, viewshed, perlin, terrain, and bump functions
- Add test_validation.py with 74 systematic tests covering type, ndim, dtype, scalar range, and integer-only checks
1 parent fe7921a commit 2d8e18f

File tree

17 files changed

+636
-88
lines changed

17 files changed

+636
-88
lines changed

xrspatial/aspect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from xrspatial.utils import _extract_latlon_coords
2121
from xrspatial.utils import _pad_array
2222
from xrspatial.utils import _validate_boundary
23+
from xrspatial.utils import _validate_raster
2324
from xrspatial.utils import cuda_args
2425
from xrspatial.utils import ngjit
2526
from xrspatial.dataset_support import supports_dataset
@@ -396,6 +397,8 @@ def aspect(agg: xr.DataArray,
396397
>>> aspect_agg = aspect(raster)
397398
"""
398399

400+
_validate_raster(agg, func_name='aspect', name='agg')
401+
399402
if method not in ('planar', 'geodesic'):
400403
raise ValueError(
401404
f"method must be 'planar' or 'geodesic', got {method!r}"

xrspatial/bump.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import xarray as xr
55
from xarray import DataArray
66

7-
from xrspatial.utils import ngjit
7+
from xrspatial.utils import _validate_scalar, ngjit
88

99
# TODO: change parameters to take agg instead of height / width
1010

@@ -194,6 +194,9 @@ def heights(locations, src, src_range, height = 20):
194194
Description: Example Bump Map
195195
units: km
196196
"""
197+
_validate_scalar(width, func_name='bump', name='width', dtype=int, min_val=1)
198+
_validate_scalar(height, func_name='bump', name='height', dtype=int, min_val=1)
199+
197200
linx = range(width)
198201
liny = range(height)
199202

xrspatial/classify.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,14 @@ class cupy(object):
2424
import numba as nb
2525
import numpy as np
2626

27-
from xrspatial.utils import ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func
27+
from xrspatial.utils import (
28+
ArrayTypeFunctionMapping,
29+
_validate_raster,
30+
_validate_scalar,
31+
cuda_args,
32+
ngjit,
33+
not_implemented_func,
34+
)
2835
from xrspatial.dataset_support import supports_dataset
2936

3037

@@ -136,6 +143,7 @@ def binary(agg, values, name='binary'):
136143
[0., 0., 0., 0., np.nan]], dtype=float32)
137144
Dimensions without coordinates: dim_0, dim_1
138145
"""
146+
_validate_raster(agg, func_name='binary', name='agg', ndim=None)
139147

140148
mapper = ArrayTypeFunctionMapping(numpy_func=_run_numpy_binary,
141149
dask_func=_run_dask_numpy_binary,
@@ -380,6 +388,7 @@ def reclassify(agg: xr.DataArray,
380388
381389
Reclassify works with Dask with CuPy backed xarray DataArray.
382390
"""
391+
_validate_raster(agg, func_name='reclassify', name='agg', ndim=None)
383392

384393
if len(bins) != len(new_values):
385394
raise ValueError(
@@ -515,6 +524,8 @@ def quantile(agg: xr.DataArray,
515524
Attributes:
516525
res: (10.0, 10.0)
517526
"""
527+
_validate_raster(agg, func_name='quantile', name='agg', ndim=None)
528+
_validate_scalar(k, func_name='quantile', name='k', dtype=int, min_val=2)
518529

519530
q = _quantile(agg, num_sample, k)
520531
k_q = q.shape[0]
@@ -836,6 +847,8 @@ def natural_breaks(agg: xr.DataArray,
836847
[ 4., 4., 4., 4., nan]], dtype=float32)
837848
Dimensions without coordinates: dim_0, dim_1
838849
"""
850+
_validate_raster(agg, func_name='natural_breaks', name='agg', ndim=None)
851+
_validate_scalar(k, func_name='natural_breaks', name='k', dtype=int, min_val=2)
839852

840853
mapper = ArrayTypeFunctionMapping(
841854
numpy_func=_run_natural_break,
@@ -944,6 +957,8 @@ def equal_interval(agg: xr.DataArray,
944957
Attributes:
945958
res: (10.0, 10.0)
946959
"""
960+
_validate_raster(agg, func_name='equal_interval', name='agg', ndim=None)
961+
_validate_scalar(k, func_name='equal_interval', name='k', dtype=int, min_val=1)
947962

948963
mapper = ArrayTypeFunctionMapping(
949964
numpy_func=lambda *args: _run_equal_interval(*args, module=np),
@@ -1015,6 +1030,8 @@ def std_mean(agg: xr.DataArray,
10151030
----------
10161031
- PySAL: https://pysal.org/mapclassify/_modules/mapclassify/classifiers.html#StdMean
10171032
"""
1033+
_validate_raster(agg, func_name='std_mean', name='agg', ndim=None)
1034+
10181035
mapper = ArrayTypeFunctionMapping(
10191036
numpy_func=lambda *args: _run_std_mean(*args, module=np),
10201037
dask_func=lambda *args: _run_std_mean(*args, module=da),
@@ -1112,6 +1129,8 @@ def head_tail_breaks(agg: xr.DataArray,
11121129
----------
11131130
- PySAL: https://pysal.org/mapclassify/_modules/mapclassify/classifiers.html#HeadTailBreaks
11141131
"""
1132+
_validate_raster(agg, func_name='head_tail_breaks', name='agg', ndim=None)
1133+
11151134
mapper = ArrayTypeFunctionMapping(
11161135
numpy_func=lambda *args: _run_head_tail_breaks(*args, module=np),
11171136
dask_func=_run_dask_head_tail_breaks,
@@ -1191,6 +1210,8 @@ def percentiles(agg: xr.DataArray,
11911210
----------
11921211
- PySAL: https://pysal.org/mapclassify/_modules/mapclassify/classifiers.html#Percentiles
11931212
"""
1213+
_validate_raster(agg, func_name='percentiles', name='agg', ndim=None)
1214+
11941215
if pct is None:
11951216
pct = [1, 10, 50, 90, 99]
11961217

@@ -1328,6 +1349,9 @@ def maximum_breaks(agg: xr.DataArray,
13281349
----------
13291350
- PySAL: https://pysal.org/mapclassify/_modules/mapclassify/classifiers.html#MaximumBreaks
13301351
"""
1352+
_validate_raster(agg, func_name='maximum_breaks', name='agg', ndim=None)
1353+
_validate_scalar(k, func_name='maximum_breaks', name='k', dtype=int, min_val=2)
1354+
13311355
mapper = ArrayTypeFunctionMapping(
13321356
numpy_func=lambda *args: _run_maximum_breaks(*args, module=np),
13331357
dask_func=_run_dask_maximum_breaks,
@@ -1431,6 +1455,8 @@ def box_plot(agg: xr.DataArray,
14311455
----------
14321456
- PySAL: https://pysal.org/mapclassify/_modules/mapclassify/classifiers.html#BoxPlot
14331457
"""
1458+
_validate_raster(agg, func_name='box_plot', name='agg', ndim=None)
1459+
14341460
mapper = ArrayTypeFunctionMapping(
14351461
numpy_func=lambda *args: _run_box_plot(*args, module=np),
14361462
dask_func=lambda *args: _run_box_plot(*args, module=da),

xrspatial/cost_distance.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class cupy: # type: ignore[no-redef]
4949
ndarray = False
5050

5151
from xrspatial.utils import (
52+
_validate_raster,
5253
cuda_args, get_dataarray_resolution, ngjit,
5354
has_cuda_and_cupy, is_cupy_array, is_dask_cupy,
5455
)
@@ -1098,10 +1099,8 @@ def cost_distance(
10981099
Source pixels have cost 0. Unreachable pixels are NaN.
10991100
"""
11001101
# --- validation ---
1101-
if raster.ndim != 2:
1102-
raise ValueError("raster must be 2-D")
1103-
if friction.ndim != 2:
1104-
raise ValueError("friction must be 2-D")
1102+
_validate_raster(raster, func_name='cost_distance', name='raster')
1103+
_validate_raster(friction, func_name='cost_distance', name='friction')
11051104
if raster.shape != friction.shape:
11061105
raise ValueError("raster and friction must have the same shape")
11071106
if raster.dims != (y, x):

xrspatial/curvature.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class cupy(object):
2525
from xrspatial.utils import _boundary_to_dask
2626
from xrspatial.utils import _pad_array
2727
from xrspatial.utils import _validate_boundary
28+
from xrspatial.utils import _validate_raster
2829
from xrspatial.utils import cuda_args
2930
from xrspatial.utils import get_dataarray_resolution
3031
from xrspatial.utils import ngjit
@@ -249,6 +250,8 @@ def curvature(agg: xr.DataArray,
249250
Attributes:
250251
res: (10, 10)
251252
"""
253+
_validate_raster(agg, func_name='curvature', name='agg')
254+
252255
cellsize_x, cellsize_y = get_dataarray_resolution(agg)
253256
cellsize = (cellsize_x + cellsize_y) / 2
254257

xrspatial/focal.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class cupy(object):
2929

3030
from xrspatial.convolution import convolve_2d, custom_kernel, _convolve_2d_numpy
3131
from xrspatial.utils import (ArrayTypeFunctionMapping, _boundary_to_dask, _pad_array,
32-
_validate_boundary, cuda_args, ngjit, not_implemented_func)
32+
_validate_boundary, _validate_raster, _validate_scalar,
33+
cuda_args, ngjit, not_implemented_func)
3334
from xrspatial.dataset_support import supports_dataset
3435

3536
# TODO: Make convolution more generic with numba first-class functions.
@@ -278,6 +279,8 @@ def mean(agg, passes=1, excludes=[np.nan], name='mean', boundary='nan'):
278279
Dimensions without coordinates: dim_0, dim_1
279280
"""
280281

282+
_validate_raster(agg, func_name='mean', name='agg')
283+
_validate_scalar(passes, func_name='mean', name='passes', dtype=int, min_val=1)
281284
_validate_boundary(boundary)
282285
out = agg.data.astype(float)
283286
for i in range(passes):
@@ -507,12 +510,7 @@ def apply(raster, kernel, func=_calc_mean, name='focal_apply', boundary='nan'):
507510
[2. , 2. , 2. , 1.5]])
508511
Dimensions without coordinates: y, x
509512
"""
510-
# validate raster
511-
if not isinstance(raster, DataArray):
512-
raise TypeError("`raster` must be instance of DataArray")
513-
514-
if raster.ndim != 2:
515-
raise ValueError("`raster` must be 2D")
513+
_validate_raster(raster, func_name='apply', name='raster')
516514

517515
# Validate the kernel
518516
kernel = custom_kernel(kernel)
@@ -957,12 +955,7 @@ def focal_stats(agg,
957955
* stats (stats) object 'min' 'sum'
958956
Dimensions without coordinates: dim_0, dim_1
959957
"""
960-
# validate raster
961-
if not isinstance(agg, DataArray):
962-
raise TypeError("`agg` must be instance of DataArray")
963-
964-
if agg.ndim != 2:
965-
raise ValueError("`agg` must be 2D")
958+
_validate_raster(agg, func_name='focal_stats', name='agg')
966959

967960
# Validate the kernel
968961
kernel = custom_kernel(kernel)
@@ -1207,12 +1200,7 @@ def hotspots(raster, kernel, boundary='nan'):
12071200
Dimensions without coordinates: dim_0, dim_1
12081201
"""
12091202

1210-
# validate raster
1211-
if not isinstance(raster, DataArray):
1212-
raise TypeError("`raster` must be instance of DataArray")
1213-
1214-
if raster.ndim != 2:
1215-
raise ValueError("`raster` must be 2D")
1203+
_validate_raster(raster, func_name='hotspots', name='raster')
12161204

12171205
_validate_boundary(boundary)
12181206

xrspatial/hillshade.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from .gpu_rtx import has_rtx
1616
from .utils import (_boundary_to_dask, _pad_array, _validate_boundary,
17+
_validate_raster, _validate_scalar,
1718
calc_cuda_dims, get_dataarray_resolution,
1819
has_cuda_and_cupy, is_cupy_array, is_cupy_backed)
1920
from .dataset_support import supports_dataset
@@ -217,6 +218,10 @@ def hillshade(agg: xr.DataArray,
217218
>>> hillshade_agg = hillshade(raster)
218219
"""
219220

221+
_validate_raster(agg, func_name='hillshade', name='agg')
222+
_validate_scalar(azimuth, func_name='hillshade', name='azimuth', min_val=0, max_val=360)
223+
_validate_scalar(angle_altitude, func_name='hillshade', name='angle_altitude', min_val=0, max_val=90)
224+
220225
if shadows and not has_rtx():
221226
raise RuntimeError(
222227
"Can only calculate shadows if cupy and rtxpy are available")

xrspatial/multispectral.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from numba import cuda
1010
from xarray import DataArray
1111

12-
from xrspatial.utils import (ArrayTypeFunctionMapping, cuda_args, ngjit, not_implemented_func,
13-
validate_arrays)
12+
from xrspatial.utils import (ArrayTypeFunctionMapping, _validate_raster, cuda_args, ngjit,
13+
not_implemented_func, validate_arrays)
1414
from xrspatial.dataset_support import supports_dataset_bands
1515

1616
# 3rd-party
@@ -153,6 +153,10 @@ def arvi(nir_agg: xr.DataArray,
153153
[ 0.02488688 0.00816024 0.00068681 0.02650602]]
154154
"""
155155

156+
_validate_raster(nir_agg, func_name='arvi', name='nir_agg')
157+
_validate_raster(red_agg, func_name='arvi', name='red_agg')
158+
_validate_raster(blue_agg, func_name='arvi', name='blue_agg')
159+
156160
validate_arrays(red_agg, nir_agg, blue_agg)
157161

158162
mapper = ArrayTypeFunctionMapping(numpy_func=_arvi_cpu,
@@ -312,6 +316,10 @@ def evi(nir_agg: xr.DataArray,
312316
[-8.53211 5.486726 0.8394608 3.5043988]]
313317
"""
314318

319+
_validate_raster(nir_agg, func_name='evi', name='nir_agg')
320+
_validate_raster(red_agg, func_name='evi', name='red_agg')
321+
_validate_raster(blue_agg, func_name='evi', name='blue_agg')
322+
315323
if not red_agg.shape == nir_agg.shape == blue_agg.shape:
316324
raise ValueError("input layers expected to have equal shapes")
317325

@@ -456,6 +464,9 @@ def gci(nir_agg: xr.DataArray,
456464
[0.34822243 0.28270411 0.29641694 0.359375 ]]
457465
"""
458466

467+
_validate_raster(nir_agg, func_name='gci', name='nir_agg')
468+
_validate_raster(green_agg, func_name='gci', name='green_agg')
469+
459470
validate_arrays(nir_agg, green_agg)
460471

461472
mapper = ArrayTypeFunctionMapping(numpy_func=_gci_cpu,
@@ -540,6 +551,9 @@ def nbr(nir_agg: xr.DataArray,
540551
[-0.10823033 -0.14486392 -0.12981689 -0.12121212]]
541552
"""
542553

554+
_validate_raster(nir_agg, func_name='nbr', name='nir_agg')
555+
_validate_raster(swir2_agg, func_name='nbr', name='swir2_agg')
556+
543557
validate_arrays(nir_agg, swir2_agg)
544558

545559
mapper = ArrayTypeFunctionMapping(
@@ -631,6 +645,9 @@ def nbr2(swir1_agg: xr.DataArray,
631645
[0.07218576 0.06857143 0.067659 0.07520281]]
632646
"""
633647

648+
_validate_raster(swir1_agg, func_name='nbr2', name='swir1_agg')
649+
_validate_raster(swir2_agg, func_name='nbr2', name='swir2_agg')
650+
634651
validate_arrays(swir1_agg, swir2_agg)
635652

636653
mapper = ArrayTypeFunctionMapping(
@@ -715,6 +732,9 @@ def ndvi(nir_agg: xr.DataArray,
715732
[0.06709956 0.04431737 0.04496226 0.07792632]]
716733
"""
717734

735+
_validate_raster(nir_agg, func_name='ndvi', name='nir_agg')
736+
_validate_raster(red_agg, func_name='ndvi', name='red_agg')
737+
718738
validate_arrays(nir_agg, red_agg)
719739

720740
mapper = ArrayTypeFunctionMapping(
@@ -804,6 +824,9 @@ def ndmi(nir_agg: xr.DataArray,
804824
[-0.17901748 -0.21133603 -0.19575651 -0.19464068]]
805825
"""
806826

827+
_validate_raster(nir_agg, func_name='ndmi', name='nir_agg')
828+
_validate_raster(swir1_agg, func_name='ndmi', name='swir1_agg')
829+
807830
validate_arrays(nir_agg, swir1_agg)
808831

809832
mapper = ArrayTypeFunctionMapping(
@@ -994,6 +1017,9 @@ def savi(nir_agg: xr.DataArray,
9941017
[0.03353769 0.02215077 0.02247375 0.03895046]]
9951018
"""
9961019

1020+
_validate_raster(nir_agg, func_name='savi', name='nir_agg')
1021+
_validate_raster(red_agg, func_name='savi', name='red_agg')
1022+
9971023
validate_arrays(red_agg, nir_agg)
9981024

9991025
if not -1.0 <= soil_factor <= 1.0:
@@ -1138,6 +1164,10 @@ def sipi(nir_agg: xr.DataArray,
11381164
[1.2903225 1.6451613 1.9708029 1.3556485]]
11391165
"""
11401166

1167+
_validate_raster(nir_agg, func_name='sipi', name='nir_agg')
1168+
_validate_raster(red_agg, func_name='sipi', name='red_agg')
1169+
_validate_raster(blue_agg, func_name='sipi', name='blue_agg')
1170+
11411171
validate_arrays(red_agg, nir_agg, blue_agg)
11421172

11431173
mapper = ArrayTypeFunctionMapping(numpy_func=_sipi_cpu,
@@ -1314,6 +1344,10 @@ def ebbi(red_agg: xr.DataArray,
13141344
* lon (lon) float64 0.0 1.0 2.0 3.0
13151345
"""
13161346

1347+
_validate_raster(red_agg, func_name='ebbi', name='red_agg')
1348+
_validate_raster(swir_agg, func_name='ebbi', name='swir_agg')
1349+
_validate_raster(tir_agg, func_name='ebbi', name='tir_agg')
1350+
13171351
validate_arrays(red_agg, swir_agg, tir_agg)
13181352

13191353
mapper = ArrayTypeFunctionMapping(numpy_func=_ebbi_cpu,
@@ -1515,6 +1549,10 @@ def true_color(r, g, b, nodata=1, c=10.0, th=0.125, name='true_color'):
15151549
>>> true_color_img.plot.imshow()
15161550
"""
15171551

1552+
_validate_raster(r, func_name='true_color', name='r')
1553+
_validate_raster(g, func_name='true_color', name='g')
1554+
_validate_raster(b, func_name='true_color', name='b')
1555+
15181556
mapper = ArrayTypeFunctionMapping(
15191557
numpy_func=_true_color_numpy,
15201558
dask_func=_true_color_dask,

0 commit comments

Comments
 (0)