Skip to content

Commit 62919a0

Browse files
committed
add 14 new classify tests: mutation safety, edge cases, cross-backend consistency
- Missing backend: natural_breaks dask+cupy num_sample - Input mutation: verify all 5 functions don't modify input DataArray - Untested path: natural_breaks with num_sample=None - Edge cases: equal_interval k=1, all-NaN input for equal_interval and natural_breaks - Name parameter: verify default and custom name on all 5 functions - Cross-backend: verify natural_breaks cupy and dask match numpy results on a separate 10x10 dataset
1 parent 7ce3707 commit 62919a0

File tree

1 file changed

+149
-1
lines changed

1 file changed

+149
-1
lines changed

xrspatial/tests/test_classify.py

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import xarray as xr
44

55
from xrspatial import binary, equal_interval, natural_breaks, quantile, reclassify
6-
from xrspatial.tests.general_checks import (create_test_raster,
6+
from xrspatial.tests.general_checks import (assert_input_data_unmodified,
7+
create_test_raster,
78
cuda_and_cupy_available,
89
dask_array_available,
910
general_output_checks)
@@ -350,3 +351,150 @@ def test_natural_breaks_dask_numpy_num_sample(result_natural_breaks_num_sample):
350351
k, num_sample, expected_result = result_natural_breaks_num_sample
351352
dask_natural_breaks = natural_breaks(dask_agg, k=k, num_sample=num_sample)
352353
general_output_checks(dask_agg, dask_natural_breaks, expected_result, verify_dtype=True)
354+
355+
356+
@dask_array_available
357+
@cuda_and_cupy_available
358+
def test_natural_breaks_dask_cupy_num_sample(result_natural_breaks_num_sample):
359+
dask_cupy_agg = input_data('dask+cupy')
360+
k, num_sample, expected_result = result_natural_breaks_num_sample
361+
dask_cupy_natural_breaks = natural_breaks(dask_cupy_agg, k=k, num_sample=num_sample)
362+
general_output_checks(
363+
dask_cupy_agg, dask_cupy_natural_breaks, expected_result, verify_dtype=True)
364+
365+
366+
# --- Input mutation tests ---
367+
# Classification functions must not modify the input DataArray.
368+
# natural_breaks is most critical because _run_jenks sorts in-place.
369+
370+
def test_binary_does_not_modify_input():
371+
agg = input_data()
372+
original = agg.copy(deep=True)
373+
binary(agg, [1, 2, 3])
374+
assert_input_data_unmodified(original, agg)
375+
376+
377+
def test_reclassify_does_not_modify_input():
378+
agg = input_data()
379+
original = agg.copy(deep=True)
380+
reclassify(agg, bins=[10, 15, np.inf], new_values=[1, 2, 3])
381+
assert_input_data_unmodified(original, agg)
382+
383+
384+
def test_quantile_does_not_modify_input():
385+
agg = input_data()
386+
original = agg.copy(deep=True)
387+
quantile(agg, k=5)
388+
assert_input_data_unmodified(original, agg)
389+
390+
391+
def test_natural_breaks_does_not_modify_input():
392+
agg = input_data()
393+
original = agg.copy(deep=True)
394+
natural_breaks(agg, k=5)
395+
assert_input_data_unmodified(original, agg)
396+
397+
398+
def test_equal_interval_does_not_modify_input():
399+
agg = input_data()
400+
original = agg.copy(deep=True)
401+
equal_interval(agg, k=3)
402+
assert_input_data_unmodified(original, agg)
403+
404+
405+
# --- num_sample=None test ---
406+
# Tests the code path where all data is used without sampling.
407+
# For the test data (20 elements), this produces the same result
408+
# as default num_sample=20000 since 20000 > 20.
409+
410+
def test_natural_breaks_numpy_num_sample_none(result_natural_breaks):
411+
numpy_agg = input_data()
412+
k, expected_result = result_natural_breaks
413+
result = natural_breaks(numpy_agg, k=k, num_sample=None)
414+
general_output_checks(numpy_agg, result, expected_result, verify_dtype=True)
415+
416+
417+
# --- Edge cases for equal_interval ---
418+
419+
def test_equal_interval_k_equals_1():
420+
agg = input_data()
421+
result = equal_interval(agg, k=1)
422+
result_data = result.data
423+
# All finite values should be in class 0
424+
finite_mask = np.isfinite(result_data)
425+
assert np.all(result_data[finite_mask] == 0)
426+
# Non-finite input positions should be NaN in output
427+
input_finite = np.isfinite(agg.data)
428+
assert np.all(np.isnan(result_data[~input_finite]))
429+
430+
431+
# --- All-NaN edge cases ---
432+
# These document current failure behavior for degenerate inputs.
433+
434+
def test_equal_interval_all_nan():
435+
data = np.full((4, 5), np.nan)
436+
agg = xr.DataArray(data)
437+
with pytest.raises(ValueError):
438+
equal_interval(agg, k=3)
439+
440+
441+
def test_natural_breaks_all_nan():
442+
data = np.full((4, 5), np.nan)
443+
agg = xr.DataArray(data)
444+
with pytest.raises(ValueError):
445+
natural_breaks(agg, k=3)
446+
447+
448+
# --- Name parameter tests ---
449+
450+
def test_output_name_default():
451+
agg = input_data()
452+
assert binary(agg, [1, 2]).name == 'binary'
453+
assert reclassify(agg, [10, 15], [1, 2]).name == 'reclassify'
454+
assert quantile(agg, k=3).name == 'quantile'
455+
assert natural_breaks(agg, k=3).name == 'natural_breaks'
456+
assert equal_interval(agg, k=3).name == 'equal_interval'
457+
458+
459+
def test_output_name_custom():
460+
agg = input_data()
461+
assert binary(agg, [1, 2], name='custom').name == 'custom'
462+
assert reclassify(agg, [10, 15], [1, 2], name='custom').name == 'custom'
463+
assert quantile(agg, k=3, name='custom').name == 'custom'
464+
assert natural_breaks(agg, k=3, name='custom').name == 'custom'
465+
assert equal_interval(agg, k=3, name='custom').name == 'custom'
466+
467+
468+
# --- Cross-backend consistency for natural_breaks ---
469+
# Verifies that cupy/dask backends produce identical results to numpy
470+
# using a different dataset (10x10 arange) than the fixture tests.
471+
472+
@cuda_and_cupy_available
473+
def test_natural_breaks_cupy_matches_numpy():
474+
import cupy as cp
475+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
476+
numpy_agg = xr.DataArray(elevation)
477+
cupy_agg = xr.DataArray(cp.asarray(elevation))
478+
479+
k = 5
480+
numpy_result = natural_breaks(numpy_agg, k=k)
481+
cupy_result = natural_breaks(cupy_agg, k=k)
482+
483+
np.testing.assert_allclose(
484+
numpy_result.data, cp.asnumpy(cupy_result.data), equal_nan=True
485+
)
486+
487+
488+
@dask_array_available
489+
def test_natural_breaks_dask_matches_numpy():
490+
elevation = np.arange(100, dtype=np.float64).reshape(10, 10)
491+
numpy_agg = xr.DataArray(elevation)
492+
dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5)))
493+
494+
k = 5
495+
numpy_result = natural_breaks(numpy_agg, k=k)
496+
dask_result = natural_breaks(dask_agg, k=k)
497+
498+
np.testing.assert_allclose(
499+
numpy_result.data, dask_result.data.compute(), equal_nan=True
500+
)

0 commit comments

Comments
 (0)