Skip to content

Commit a6b636c

Browse files
authored
added cupy support to true_color function (#866)
1 parent fe4c952 commit a6b636c

File tree

3 files changed

+86
-9
lines changed

3 files changed

+86
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e
176176
| [Normalized Difference Vegetation Index (NDVI)](xrspatial/multispectral.py) | Quantifies vegetation density from red and NIR band difference | ✅️ |✅️ | ✅️ |✅️ |
177177
| [Soil Adjusted Vegetation Index (SAVI)](xrspatial/multispectral.py) | Vegetation index with soil brightness correction factor | ✅️ |✅️ | ✅️ |✅️ |
178178
| [Structure Insensitive Pigment Index (SIPI)](xrspatial/multispectral.py) | Estimates carotenoid-to-chlorophyll ratio for plant stress detection | ✅️ |✅️ | ✅️ |✅️ |
179-
| [True Color](xrspatial/multispectral.py) | Composites red, green, and blue bands into a natural color image | ✅️ || ✅️ ||
179+
| [True Color](xrspatial/multispectral.py) | Composites red, green, and blue bands into a natural color image | ✅️ | | ✅️ | |
180180

181181
-------
182182

xrspatial/multispectral.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,11 +1371,35 @@ def _normalize_data_dask(data, pixel_max, c, th):
13711371

13721372

13731373
def _normalize_data_cupy(data, pixel_max, c, th):
1374-
raise NotImplementedError('Not Supported')
1374+
min_val = cupy.nanmin(data)
1375+
max_val = cupy.nanmax(data)
1376+
range_val = max_val - min_val
1377+
out = cupy.full(data.shape, cupy.nan, dtype=cupy.float32)
1378+
if range_val != 0:
1379+
norm = (data - min_val) / range_val
1380+
norm = 1 / (1 + cupy.exp(c * (th - norm)))
1381+
out = norm * pixel_max
1382+
return out
1383+
1384+
1385+
def _normalize_data_cupy_block(data, min_val, max_val, pixel_max, c, th):
1386+
range_val = max_val - min_val
1387+
out = cupy.full(data.shape, cupy.nan, dtype=cupy.float32)
1388+
if range_val != 0:
1389+
norm = (data - min_val) / range_val
1390+
norm = 1 / (1 + cupy.exp(c * (th - norm)))
1391+
out = norm * pixel_max
1392+
return out
13751393

13761394

13771395
def _normalize_data_dask_cupy(data, pixel_max, c, th):
1378-
raise NotImplementedError('Not Supported')
1396+
min_val = da.nanmin(data)
1397+
max_val = da.nanmax(data)
1398+
out = da.map_blocks(
1399+
_normalize_data_cupy_block, data, min_val, max_val, pixel_max,
1400+
c, th, meta=cupy.array(())
1401+
)
1402+
return out
13791403

13801404

13811405
def _normalize_data(agg, pixel_max, c, th):
@@ -1416,6 +1440,32 @@ def _true_color_dask(r, g, b, nodata, c, th):
14161440
return out
14171441

14181442

1443+
def _true_color_cupy(r, g, b, nodata, c, th):
1444+
pixel_max = 255
1445+
r_data = r.data
1446+
a = cupy.where(
1447+
cupy.logical_or(cupy.isnan(r_data), r_data <= nodata), 0, pixel_max
1448+
).astype(cupy.uint8)
1449+
red = (_normalize_data(r, pixel_max, c, th)).astype(cupy.uint8)
1450+
green = (_normalize_data(g, pixel_max, c, th)).astype(cupy.uint8)
1451+
blue = (_normalize_data(b, pixel_max, c, th)).astype(cupy.uint8)
1452+
out = cupy.stack([red, green, blue, a], axis=-1)
1453+
return out
1454+
1455+
1456+
def _true_color_dask_cupy(r, g, b, nodata, c, th):
1457+
pixel_max = 255
1458+
r_data = r.data
1459+
alpha = da.where(
1460+
da.logical_or(da.isnan(r_data), r_data <= nodata), 0, pixel_max
1461+
).astype(cupy.uint8)
1462+
red = (_normalize_data(r, pixel_max, c, th)).astype(cupy.uint8)
1463+
green = (_normalize_data(g, pixel_max, c, th)).astype(cupy.uint8)
1464+
blue = (_normalize_data(b, pixel_max, c, th)).astype(cupy.uint8)
1465+
out = da.stack([red, green, blue, alpha], axis=-1)
1466+
return out
1467+
1468+
14191469
def true_color(r, g, b, nodata=1, c=10.0, th=0.125, name='true_color'):
14201470
"""
14211471
Create true color composite from a combination of red, green and
@@ -1468,12 +1518,8 @@ def true_color(r, g, b, nodata=1, c=10.0, th=0.125, name='true_color'):
14681518
mapper = ArrayTypeFunctionMapping(
14691519
numpy_func=_true_color_numpy,
14701520
dask_func=_true_color_dask,
1471-
cupy_func=lambda *args: not_implemented_func(
1472-
*args, messages='true_color() does not support cupy backed DataArray', # noqa
1473-
),
1474-
dask_cupy_func=lambda *args: not_implemented_func(
1475-
*args, messages='true_color() does not support dask with cupy backed DataArray', # noqa
1476-
),
1521+
cupy_func=_true_color_cupy,
1522+
dask_cupy_func=_true_color_dask_cupy,
14771523
)
14781524
with warnings.catch_warnings():
14791525
warnings.simplefilter('ignore')

xrspatial/tests/test_multispectral.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,3 +613,34 @@ def test_true_color_numpy_equals_dask_numpy(random_data):
613613
np.testing.assert_allclose(
614614
numpy_result.data, dask_result.compute().data, equal_nan=True
615615
)
616+
617+
618+
@cuda_and_cupy_available
619+
@pytest.mark.parametrize("size", [(2, 4), (10, 15)])
620+
@pytest.mark.parametrize(
621+
"dtype", [np.int32, np.int64, np.uint32, np.uint64, np.float32, np.float64])
622+
@pytest.mark.parametrize("backend", ["cupy", "dask+cupy"])
623+
def test_true_color_gpu(random_data, backend):
624+
# numpy baseline
625+
red_np = create_test_raster(random_data, backend="numpy")
626+
green_np = create_test_raster(random_data, backend="numpy")
627+
blue_np = create_test_raster(random_data, backend="numpy")
628+
numpy_result = true_color(red_np, green_np, blue_np)
629+
630+
# gpu version
631+
red_gpu = create_test_raster(random_data, backend=backend)
632+
green_gpu = create_test_raster(random_data, backend=backend)
633+
blue_gpu = create_test_raster(random_data, backend=backend)
634+
gpu_result = true_color(red_gpu, green_gpu, blue_gpu)
635+
636+
general_output_checks(red_gpu, gpu_result, verify_attrs=False)
637+
638+
gpu_data = gpu_result.data
639+
if hasattr(gpu_data, 'compute'):
640+
gpu_data = gpu_data.compute()
641+
if hasattr(gpu_data, 'get'):
642+
gpu_data = gpu_data.get()
643+
644+
np.testing.assert_allclose(
645+
numpy_result.data, gpu_data, equal_nan=True
646+
)

0 commit comments

Comments
 (0)