Skip to content

Commit 819ee71

Browse files
authored
Add focal variety statistic (#1040) (#1043)
* Add focal variety statistic (#1040) Add _calc_variety (CPU/numba) and _focal_variety_cuda (GPU) that count distinct non-NaN values in a kernel neighbourhood. Wire into all three dispatch dicts so variety works across numpy, cupy, dask+numpy, and dask+cupy backends. * Add tests for focal variety (#1040) Cover correctness on categorical data, NaN handling, all-NaN windows, single-cell rasters, and dask+numpy backend parity. Update the data_focal_stats fixture to include variety expected values. * Add focal variety user guide notebook (#1040) * Trigger RTD rebuild
1 parent 9a38f2d commit 819ee71

File tree

3 files changed

+365
-5
lines changed

3 files changed

+365
-5
lines changed
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Focal Variety\n",
8+
"\n",
9+
"Focal variety counts the number of distinct values in a sliding\n",
10+
"neighbourhood window. It is most useful for categorical rasters\n",
11+
"(land-cover, soil type, geology codes) where you want to map\n",
12+
"boundary complexity or patch fragmentation.\n",
13+
"\n",
14+
"This notebook shows how to compute focal variety with\n",
15+
"`xrspatial.focal.focal_stats` across different kernel shapes."
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"metadata": {},
22+
"outputs": [],
23+
"source": [
24+
"import numpy as np\n",
25+
"import xarray as xr\n",
26+
"import matplotlib.pyplot as plt\n",
27+
"\n",
28+
"from xrspatial.convolution import circle_kernel, custom_kernel\n",
29+
"from xrspatial.focal import focal_stats"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"## Create a synthetic land-cover raster\n",
37+
"\n",
38+
"We build a 60x60 grid with four land-cover classes arranged in\n",
39+
"quadrants, plus a few scattered patches to make things interesting."
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {},
46+
"outputs": [],
47+
"source": [
48+
"rng = np.random.default_rng(42)\n",
49+
"rows, cols = 60, 60\n",
50+
"\n",
51+
"# Four quadrants: classes 1-4\n",
52+
"lc = np.ones((rows, cols), dtype=np.float64)\n",
53+
"lc[:rows//2, cols//2:] = 2\n",
54+
"lc[rows//2:, :cols//2] = 3\n",
55+
"lc[rows//2:, cols//2:] = 4\n",
56+
"\n",
57+
"# Scatter some class-5 patches\n",
58+
"for _ in range(30):\n",
59+
" r, c = rng.integers(0, rows), rng.integers(0, cols)\n",
60+
" lc[r:r+3, c:c+3] = 5\n",
61+
"\n",
62+
"land_cover = xr.DataArray(lc, dims=['y', 'x'], name='land_cover')\n",
63+
"\n",
64+
"fig, ax = plt.subplots(figsize=(5, 5))\n",
65+
"land_cover.plot(ax=ax, cmap='Set2', add_colorbar=True)\n",
66+
"ax.set_title('Synthetic land-cover raster')\n",
67+
"ax.set_aspect('equal')\n",
68+
"plt.tight_layout()\n",
69+
"plt.show()"
70+
]
71+
},
72+
{
73+
"cell_type": "markdown",
74+
"metadata": {},
75+
"source": [
76+
"## Compute focal variety with a 3x3 box kernel\n",
77+
"\n",
78+
"A 3x3 box kernel counts how many distinct classes appear in the\n",
79+
"immediate 8-connected neighbourhood of each pixel."
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {},
86+
"outputs": [],
87+
"source": [
88+
"kernel_box = np.ones((3, 3))\n",
89+
"result_box = focal_stats(land_cover, kernel_box, stats_funcs=['variety'])\n",
90+
"variety_box = result_box.sel(stats='variety')\n",
91+
"\n",
92+
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
93+
"land_cover.plot(ax=axes[0], cmap='Set2', add_colorbar=True)\n",
94+
"axes[0].set_title('Land cover')\n",
95+
"axes[0].set_aspect('equal')\n",
96+
"\n",
97+
"variety_box.plot(ax=axes[1], cmap='YlOrRd', add_colorbar=True)\n",
98+
"axes[1].set_title('Focal variety (3x3 box)')\n",
99+
"axes[1].set_aspect('equal')\n",
100+
"plt.tight_layout()\n",
101+
"plt.show()"
102+
]
103+
},
104+
{
105+
"cell_type": "markdown",
106+
"metadata": {},
107+
"source": [
108+
"Pixels deep inside a uniform quadrant show variety = 1. Pixels on\n",
109+
"boundaries between classes show variety = 2, 3, or 4 depending on\n",
110+
"how many classes meet at that point. The scattered class-5 patches\n",
111+
"create small pockets of higher variety."
112+
]
113+
},
114+
{
115+
"cell_type": "markdown",
116+
"metadata": {},
117+
"source": [
118+
"## Larger kernel: 5x5 circle\n",
119+
"\n",
120+
"Increasing the kernel radius captures more of the surrounding\n",
121+
"landscape, so variety values near boundaries will be higher."
122+
]
123+
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": null,
127+
"metadata": {},
128+
"outputs": [],
129+
"source": [
130+
"kernel_circle = circle_kernel(2, 2, 2)\n",
131+
"result_circle = focal_stats(land_cover, kernel_circle, stats_funcs=['variety'])\n",
132+
"variety_circle = result_circle.sel(stats='variety')\n",
133+
"\n",
134+
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
135+
"variety_box.plot(ax=axes[0], cmap='YlOrRd', add_colorbar=True)\n",
136+
"axes[0].set_title('Variety (3x3 box)')\n",
137+
"axes[0].set_aspect('equal')\n",
138+
"\n",
139+
"variety_circle.plot(ax=axes[1], cmap='YlOrRd', add_colorbar=True)\n",
140+
"axes[1].set_title('Variety (5x5 circle)')\n",
141+
"axes[1].set_aspect('equal')\n",
142+
"plt.tight_layout()\n",
143+
"plt.show()"
144+
]
145+
},
146+
{
147+
"cell_type": "markdown",
148+
"metadata": {},
149+
"source": [
150+
"## Combining variety with other focal stats\n",
151+
"\n",
152+
"You can request variety alongside other statistics in one call.\n",
153+
"Here we grab both range and variety to compare continuous and\n",
154+
"categorical measures of local heterogeneity."
155+
]
156+
},
157+
{
158+
"cell_type": "code",
159+
"execution_count": null,
160+
"metadata": {},
161+
"outputs": [],
162+
"source": [
163+
"result_combo = focal_stats(land_cover, kernel_box,\n",
164+
" stats_funcs=['range', 'variety'])\n",
165+
"\n",
166+
"fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
167+
"result_combo.sel(stats='range').plot(ax=axes[0], cmap='viridis',\n",
168+
" add_colorbar=True)\n",
169+
"axes[0].set_title('Focal range')\n",
170+
"axes[0].set_aspect('equal')\n",
171+
"\n",
172+
"result_combo.sel(stats='variety').plot(ax=axes[1], cmap='YlOrRd',\n",
173+
" add_colorbar=True)\n",
174+
"axes[1].set_title('Focal variety')\n",
175+
"axes[1].set_aspect('equal')\n",
176+
"plt.tight_layout()\n",
177+
"plt.show()"
178+
]
179+
},
180+
{
181+
"cell_type": "markdown",
182+
"metadata": {},
183+
"source": [
184+
"Range measures the numeric spread (max minus min) while variety\n",
185+
"counts distinct classes. For categorical data, variety is usually\n",
186+
"the more meaningful measure since the numeric distance between\n",
187+
"class codes is arbitrary."
188+
]
189+
}
190+
],
191+
"metadata": {
192+
"kernelspec": {
193+
"display_name": "Python 3",
194+
"language": "python",
195+
"name": "python3"
196+
},
197+
"language_info": {
198+
"name": "python",
199+
"version": "3.10.0"
200+
}
201+
},
202+
"nbformat": 4,
203+
"nbformat_minor": 4
204+
}

xrspatial/focal.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,28 @@ def _calc_var(array):
352352
return np.nanvar(array)
353353

354354

355+
@ngjit
356+
def _calc_variety(array):
357+
"""Count distinct non-NaN values in the flat kernel neighbourhood."""
358+
count = 0
359+
uvals = np.empty(array.size, dtype=array.dtype)
360+
for i in range(array.size):
361+
v = array.flat[i]
362+
if np.isnan(v):
363+
continue
364+
found = False
365+
for j in range(count):
366+
if uvals[j] == v:
367+
found = True
368+
break
369+
if not found:
370+
uvals[count] = v
371+
count += 1
372+
if count == 0:
373+
return np.nan
374+
return np.float64(count)
375+
376+
355377
@ngjit
356378
def _apply_numpy(data, kernel, func):
357379
data = data.astype(np.float32)
@@ -762,6 +784,52 @@ def _focal_var_cuda(data, kernel, out):
762784
out[i, j] = 0.0
763785

764786

787+
@cuda.jit
788+
def _focal_variety_cuda(data, kernel, out):
789+
i, j = cuda.grid(2)
790+
791+
rows, cols = data.shape
792+
if i >= rows or j >= cols:
793+
return
794+
795+
dr = kernel.shape[0] // 2
796+
dc = kernel.shape[1] // 2
797+
798+
# Local buffer for up to 25 unique values (covers kernels up to 5x5).
799+
# For larger kernels the buffer simply fills and stops counting,
800+
# which is an acceptable trade-off for GPU register pressure.
801+
MAX_UNIQ = 25
802+
buf = cuda.local.array(MAX_UNIQ, nb.float32)
803+
count = 0
804+
805+
for k in range(kernel.shape[0]):
806+
for h in range(kernel.shape[1]):
807+
if kernel[k, h] == 0:
808+
continue
809+
810+
ii = i + k - dr
811+
jj = j + h - dc
812+
813+
if 0 <= ii < rows and 0 <= jj < cols:
814+
v = data[ii, jj]
815+
if v != v: # NaN check (NaN != NaN)
816+
continue
817+
# check if already in buffer
818+
found = False
819+
for u in range(count):
820+
if buf[u] == v:
821+
found = True
822+
break
823+
if not found and count < MAX_UNIQ:
824+
buf[count] = v
825+
count += 1
826+
827+
if count == 0:
828+
out[i, j] = math.nan
829+
else:
830+
out[i, j] = float(count)
831+
832+
765833
def _focal_mean_cupy(data, kernel):
766834
out = convolve_2d(data, kernel / kernel.sum())
767835
return out
@@ -852,6 +920,7 @@ def _focal_stats_cupy(agg, kernel, stats_funcs):
852920
min=lambda *args: _focal_stats_func_cupy(*args, func=_focal_min_cuda),
853921
std=lambda *args: _focal_stats_func_cupy(*args, func=_focal_std_cuda),
854922
var=lambda *args: _focal_stats_func_cupy(*args, func=_focal_var_cuda),
923+
variety=lambda *args: _focal_stats_func_cupy(*args, func=_focal_variety_cuda),
855924
)
856925
stats_aggs = []
857926
for stats in stats_funcs:
@@ -873,6 +942,7 @@ def _focal_stats_dask_cupy(agg, kernel, stats_funcs, boundary='nan'):
873942
mean=_focal_mean_cuda, sum=_focal_sum_cuda,
874943
range=_focal_range_cuda, max=_focal_max_cuda,
875944
min=_focal_min_cuda, std=_focal_std_cuda, var=_focal_var_cuda,
945+
variety=_focal_variety_cuda,
876946
)
877947
pad_h = kernel.shape[0] // 2
878948
pad_w = kernel.shape[1] // 2
@@ -902,7 +972,8 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
902972
'range': _calc_range,
903973
'std': _calc_std,
904974
'var': _calc_var,
905-
'sum': _calc_sum
975+
'sum': _calc_sum,
976+
'variety': _calc_variety,
906977
}
907978
stats_aggs = []
908979
for stats in stats_funcs:
@@ -916,13 +987,14 @@ def _focal_stats_cpu(agg, kernel, stats_funcs, boundary='nan'):
916987
def focal_stats(agg,
917988
kernel,
918989
stats_funcs=[
919-
'mean', 'max', 'min', 'range', 'std', 'var', 'sum'
990+
'mean', 'max', 'min', 'range', 'std', 'var',
991+
'sum', 'variety'
920992
],
921993
boundary='nan'):
922994
"""
923995
Calculates statistics of the values within a specified focal neighborhood
924996
for each pixel in an input raster. The statistics types are Mean, Maximum,
925-
Minimum, Range, Standard deviation, Variation and Sum.
997+
Minimum, Range, Standard deviation, Variation, Sum, and Variety.
926998
927999
Parameters
9281000
----------
@@ -934,7 +1006,9 @@ def focal_stats(agg,
9341006
2D array where values of 1 indicate the kernel.
9351007
stats_funcs: list of string
9361008
List of statistics types to be calculated.
937-
Default set to ['mean', 'max', 'min', 'range', 'std', 'var', 'sum'].
1009+
Default set to ['mean', 'max', 'min', 'range', 'std', 'var',
1010+
'sum', 'variety']. ``'variety'`` counts the number of distinct
1011+
non-NaN values in the neighbourhood (useful for categorical rasters).
9381012
boundary : str, default='nan'
9391013
How to handle edges where the kernel extends beyond the raster.
9401014
``'nan'`` -- fill missing neighbours with NaN (default).

0 commit comments

Comments
 (0)