|
3 | 3 | import xarray as xr |
4 | 4 |
|
5 | 5 | from xrspatial import binary, equal_interval, natural_breaks, quantile, reclassify |
6 | | -from xrspatial.tests.general_checks import (create_test_raster, |
| 6 | +from xrspatial.tests.general_checks import (assert_input_data_unmodified, |
| 7 | + create_test_raster, |
7 | 8 | cuda_and_cupy_available, |
8 | 9 | dask_array_available, |
9 | 10 | general_output_checks) |
@@ -350,3 +351,150 @@ def test_natural_breaks_dask_numpy_num_sample(result_natural_breaks_num_sample): |
350 | 351 | k, num_sample, expected_result = result_natural_breaks_num_sample |
351 | 352 | dask_natural_breaks = natural_breaks(dask_agg, k=k, num_sample=num_sample) |
352 | 353 | general_output_checks(dask_agg, dask_natural_breaks, expected_result, verify_dtype=True) |
| 354 | + |
| 355 | + |
| 356 | +@dask_array_available |
| 357 | +@cuda_and_cupy_available |
| 358 | +def test_natural_breaks_dask_cupy_num_sample(result_natural_breaks_num_sample): |
| 359 | + dask_cupy_agg = input_data('dask+cupy') |
| 360 | + k, num_sample, expected_result = result_natural_breaks_num_sample |
| 361 | + dask_cupy_natural_breaks = natural_breaks(dask_cupy_agg, k=k, num_sample=num_sample) |
| 362 | + general_output_checks( |
| 363 | + dask_cupy_agg, dask_cupy_natural_breaks, expected_result, verify_dtype=True) |
| 364 | + |
| 365 | + |
| 366 | +# --- Input mutation tests --- |
| 367 | +# Classification functions must not modify the input DataArray. |
| 368 | +# natural_breaks is most critical because _run_jenks sorts in-place. |
| 369 | + |
| 370 | +def test_binary_does_not_modify_input(): |
| 371 | + agg = input_data() |
| 372 | + original = agg.copy(deep=True) |
| 373 | + binary(agg, [1, 2, 3]) |
| 374 | + assert_input_data_unmodified(original, agg) |
| 375 | + |
| 376 | + |
| 377 | +def test_reclassify_does_not_modify_input(): |
| 378 | + agg = input_data() |
| 379 | + original = agg.copy(deep=True) |
| 380 | + reclassify(agg, bins=[10, 15, np.inf], new_values=[1, 2, 3]) |
| 381 | + assert_input_data_unmodified(original, agg) |
| 382 | + |
| 383 | + |
| 384 | +def test_quantile_does_not_modify_input(): |
| 385 | + agg = input_data() |
| 386 | + original = agg.copy(deep=True) |
| 387 | + quantile(agg, k=5) |
| 388 | + assert_input_data_unmodified(original, agg) |
| 389 | + |
| 390 | + |
| 391 | +def test_natural_breaks_does_not_modify_input(): |
| 392 | + agg = input_data() |
| 393 | + original = agg.copy(deep=True) |
| 394 | + natural_breaks(agg, k=5) |
| 395 | + assert_input_data_unmodified(original, agg) |
| 396 | + |
| 397 | + |
| 398 | +def test_equal_interval_does_not_modify_input(): |
| 399 | + agg = input_data() |
| 400 | + original = agg.copy(deep=True) |
| 401 | + equal_interval(agg, k=3) |
| 402 | + assert_input_data_unmodified(original, agg) |
| 403 | + |
| 404 | + |
| 405 | +# --- num_sample=None test --- |
| 406 | +# Tests the code path where all data is used without sampling. |
| 407 | +# For the test data (20 elements), this produces the same result |
| 408 | +# as default num_sample=20000 since 20000 > 20. |
| 409 | + |
| 410 | +def test_natural_breaks_numpy_num_sample_none(result_natural_breaks): |
| 411 | + numpy_agg = input_data() |
| 412 | + k, expected_result = result_natural_breaks |
| 413 | + result = natural_breaks(numpy_agg, k=k, num_sample=None) |
| 414 | + general_output_checks(numpy_agg, result, expected_result, verify_dtype=True) |
| 415 | + |
| 416 | + |
| 417 | +# --- Edge cases for equal_interval --- |
| 418 | + |
| 419 | +def test_equal_interval_k_equals_1(): |
| 420 | + agg = input_data() |
| 421 | + result = equal_interval(agg, k=1) |
| 422 | + result_data = result.data |
| 423 | + # All finite values should be in class 0 |
| 424 | + finite_mask = np.isfinite(result_data) |
| 425 | + assert np.all(result_data[finite_mask] == 0) |
| 426 | + # Non-finite input positions should be NaN in output |
| 427 | + input_finite = np.isfinite(agg.data) |
| 428 | + assert np.all(np.isnan(result_data[~input_finite])) |
| 429 | + |
| 430 | + |
| 431 | +# --- All-NaN edge cases --- |
| 432 | +# These document current failure behavior for degenerate inputs. |
| 433 | + |
| 434 | +def test_equal_interval_all_nan(): |
| 435 | + data = np.full((4, 5), np.nan) |
| 436 | + agg = xr.DataArray(data) |
| 437 | + with pytest.raises(ValueError): |
| 438 | + equal_interval(agg, k=3) |
| 439 | + |
| 440 | + |
| 441 | +def test_natural_breaks_all_nan(): |
| 442 | + data = np.full((4, 5), np.nan) |
| 443 | + agg = xr.DataArray(data) |
| 444 | + with pytest.raises(ValueError): |
| 445 | + natural_breaks(agg, k=3) |
| 446 | + |
| 447 | + |
| 448 | +# --- Name parameter tests --- |
| 449 | + |
| 450 | +def test_output_name_default(): |
| 451 | + agg = input_data() |
| 452 | + assert binary(agg, [1, 2]).name == 'binary' |
| 453 | + assert reclassify(agg, [10, 15], [1, 2]).name == 'reclassify' |
| 454 | + assert quantile(agg, k=3).name == 'quantile' |
| 455 | + assert natural_breaks(agg, k=3).name == 'natural_breaks' |
| 456 | + assert equal_interval(agg, k=3).name == 'equal_interval' |
| 457 | + |
| 458 | + |
| 459 | +def test_output_name_custom(): |
| 460 | + agg = input_data() |
| 461 | + assert binary(agg, [1, 2], name='custom').name == 'custom' |
| 462 | + assert reclassify(agg, [10, 15], [1, 2], name='custom').name == 'custom' |
| 463 | + assert quantile(agg, k=3, name='custom').name == 'custom' |
| 464 | + assert natural_breaks(agg, k=3, name='custom').name == 'custom' |
| 465 | + assert equal_interval(agg, k=3, name='custom').name == 'custom' |
| 466 | + |
| 467 | + |
| 468 | +# --- Cross-backend consistency for natural_breaks --- |
| 469 | +# Verifies that cupy/dask backends produce identical results to numpy |
| 470 | +# using a different dataset (10x10 arange) than the fixture tests. |
| 471 | + |
| 472 | +@cuda_and_cupy_available |
| 473 | +def test_natural_breaks_cupy_matches_numpy(): |
| 474 | + import cupy as cp |
| 475 | + elevation = np.arange(100, dtype=np.float64).reshape(10, 10) |
| 476 | + numpy_agg = xr.DataArray(elevation) |
| 477 | + cupy_agg = xr.DataArray(cp.asarray(elevation)) |
| 478 | + |
| 479 | + k = 5 |
| 480 | + numpy_result = natural_breaks(numpy_agg, k=k) |
| 481 | + cupy_result = natural_breaks(cupy_agg, k=k) |
| 482 | + |
| 483 | + np.testing.assert_allclose( |
| 484 | + numpy_result.data, cp.asnumpy(cupy_result.data), equal_nan=True |
| 485 | + ) |
| 486 | + |
| 487 | + |
| 488 | +@dask_array_available |
| 489 | +def test_natural_breaks_dask_matches_numpy(): |
| 490 | + elevation = np.arange(100, dtype=np.float64).reshape(10, 10) |
| 491 | + numpy_agg = xr.DataArray(elevation) |
| 492 | + dask_agg = xr.DataArray(da.from_array(elevation, chunks=(5, 5))) |
| 493 | + |
| 494 | + k = 5 |
| 495 | + numpy_result = natural_breaks(numpy_agg, k=k) |
| 496 | + dask_result = natural_breaks(dask_agg, k=k) |
| 497 | + |
| 498 | + np.testing.assert_allclose( |
| 499 | + numpy_result.data, dask_result.data.compute(), equal_nan=True |
| 500 | + ) |
0 commit comments