Commit 6884e01

Add 2D image bilateral example with astronaut figures and PSNR
1 parent 5603dfa commit 6884e01

7 files changed

Lines changed: 240 additions & 1 deletion

docs/user_guide/bilateral_permutohedral_filters.md

Lines changed: 84 additions & 1 deletion
@@ -1,7 +1,7 @@
# Bilateral and Permutohedral Filters

**Created**: 2026-04-30 19:55:27 PST
-**Edited**: 2026-04-30 21:30:00 PST
+**Edited**: 2026-04-30 22:50:38 PST

WarpConvNet ships three families of edge-preserving filters for point clouds
and high-dimensional feature volumes. They differ in the underlying spatial
@@ -159,6 +159,89 @@ Sinkhorn-style bistochastization runs by default; disable via the
underlying `bilateral_solver(grid, ..., bistochastize=False)` if extreme
confidence values cause non-finite scaling factors.

## Worked example: 2D image denoising

Image denoising is the canonical bilateral demo: each pixel is a point with
a 2D position and a 3D color, the guide is `concat(xy/sigma_xy, rgb/sigma_rgb)`, and the value being filtered is the noisy color itself.
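
To make the guide concrete, here is a minimal NumPy sketch that builds the 5-D guide for a tiny image. The helper name `build_guide` and the 2×3 gray test image are illustrative assumptions, not part of WarpConvNet; the filter modules take `sigma_xyz` and `sigma_feat` and presumably apply this scaling internally.

```python
import numpy as np


def build_guide(rgb, sigma_xy, sigma_rgb):
    """Return the (H*W, 5) guide concat(xy / sigma_xy, rgb / sigma_rgb).

    Hypothetical helper for illustration only.
    """
    h, w, _ = rgb.shape
    # Pixel coordinates in row-major order, stored as (x, y) pairs.
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    xy = np.stack([xs, ys], axis=-1).reshape(-1, 2).astype(np.float64)
    colors = np.asarray(rgb, dtype=np.float64).reshape(-1, 3)
    # Dividing by the bandwidths makes a unit step in guide space equal to
    # one standard deviation along every axis.
    return np.concatenate([xy / sigma_xy, colors / sigma_rgb], axis=1)


# A 2x3 mid-gray image, used only to show shapes and scaling.
guide = build_guide(np.full((2, 3, 3), 0.5), sigma_xy=4.0, sigma_rgb=0.1)
print(guide.shape)  # (6, 5)
```

Row 1 of `guide` is the pixel at (x=1, y=0): its position becomes (0.25, 0) and each color channel becomes 0.5 / 0.1 = 5, so a color difference of 0.1 counts as much as a spatial offset of 4 pixels.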

The example below applies all three filter families to the NASA `astronaut`
test image (public domain, shipped with `scikit-image`) corrupted with
Gaussian noise of variance 0.01. End-to-end times are on an RTX 6000 Ada at
512×512 = 262k "points" with $\sigma_{xy} = 4$, $\sigma_{rgb} = 0.1$.
172+
### Input
173+
174+
<table>
175+
<tr>
176+
<td align="center"><b>Original</b></td>
177+
<td align="center"><b>Noisy (Gaussian, var=0.01) — 20.70 dB</b></td>
178+
</tr>
179+
<tr>
180+
<td><img src="img/astronaut_original.jpg" alt="Original astronaut" width="100%"></td>
181+
<td><img src="img/astronaut_noisy.jpg" alt="Noisy astronaut" width="100%"></td>
182+
</tr>
183+
</table>
184+
185+
### Output
186+
187+
<table>
188+
<tr>
189+
<td align="center"><b>KNN (k=24) — 23.67 dB / ~3.3 s</b></td>
190+
<td align="center"><b>Grid — 23.94 dB / ~63 ms</b></td>
191+
<td align="center"><b>Permutohedral — 24.68 dB / ~11 ms</b></td>
192+
</tr>
193+
<tr>
194+
<td><img src="img/astronaut_knn.jpg" alt="KNN bilateral" width="100%"></td>
195+
<td><img src="img/astronaut_grid.jpg" alt="Grid bilateral" width="100%"></td>
196+
<td><img src="img/astronaut_permutohedral.jpg" alt="Permutohedral bilateral" width="100%"></td>
197+
</tr>
198+
</table>
199+
200+
| Filter                         | Time   | PSNR (dB) | Notes                          |
| ------------------------------ | ------ | --------- | ------------------------------ |
| Noisy input (reference)        | -      | 20.70     | Gaussian noise, var = 0.01     |
| `BilateralFilter` (KNN, k=24)  | ~3.3 s | 23.67     | Exact Gaussian, $O(N \cdot k)$ |
| `BilateralFilterGrid`          | ~63 ms | 23.94     | $d=5$ sparse cube              |
| `BilateralPermutohedralFilter` | ~11 ms | **24.68** | $d=5$ permutohedral lattice    |

PSNR is computed against the clean original with `data_range=1.0`. The
permutohedral lattice is both fastest *and* highest PSNR here: its
conservative reconstruction (3-tap Gaussian on $(d{+}1)$ lattice axes)
preserves edges slightly better than the $d$-cube grid at this bandwidth,
and noticeably better than the KNN filter, where the limited $k = 24$
truncates the neighborhood and clips the kernel tails.
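
As a sanity check on the PSNR figures, recall that $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$ with $R$ the data range. For unclipped Gaussian noise of variance 0.01 and $R = 1$, this gives $10 \log_{10}(1 / 0.01) = 20$ dB. The sketch below reproduces that number with plain NumPy on a synthetic mid-gray image. The image is a stand-in, not the astronaut photo, and the local `psnr` function is a hand-written equivalent of the formula rather than the `skimage.metrics.peak_signal_noise_ratio` call the script uses.

```python
import numpy as np


def psnr(ref, test, data_range=1.0):
    """Peak signal-to-noise ratio in dB; undefined when ref == test (MSE = 0)."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    mse = np.mean((ref - test) ** 2)
    return 10.0 * np.log10(data_range**2 / mse)


rng = np.random.default_rng(0)
clean = np.full((256, 256, 3), 0.5)                # synthetic stand-in image
noise = rng.normal(0.0, np.sqrt(0.01), clean.shape)  # variance 0.01
noisy = clean + noise                              # deliberately not clipped

value = psnr(clean, noisy)
print(f"{value:.1f} dB")
```

The 20.70 dB reported for the noisy astronaut is slightly above this 20 dB baseline, which is consistent with `skimage.util.random_noise` clipping its output to the valid range by default: clipping can only pull out-of-range pixels back toward the clean values.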

Reproduce with [`examples/bilateral_image_example.py`](https://github.com/NVIDIA/warpconvnet/blob/main/examples/bilateral_image_example.py):

```bash
python examples/bilateral_image_example.py \
    --out-dir docs/user_guide/img \
    --sigma-xy 4.0 --sigma-rgb 0.1 --noise-var 0.01
```

Sketch of the call site:

```python
from skimage import data, util
import torch
import warpconvnet.nn as wn

img = util.img_as_float(data.astronaut())  # (512, 512, 3)
noisy = util.random_noise(img, mode="gaussian", var=0.01)
h, w, _ = img.shape

ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
xy = torch.stack([xs, ys], dim=-1).reshape(-1, 2).float().cuda()
rgb = torch.from_numpy(noisy).reshape(-1, 3).float().cuda()

filt = wn.BilateralPermutohedralFilter(sigma_xyz=4.0, sigma_feat=0.1)
denoised = filt(xy, rgb, rgb).reshape(h, w, 3).clamp(0, 1)
```
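
The KNN filter evaluates the Gaussian kernel exactly over its $k$ neighbors
but costs $O(N \cdot k)$; the lattice variants trade ~10–20% reconstruction
error for a speedup of roughly 50× to 300× in the run above (~3.3 s vs.
~63 ms and ~11 ms). They become the only practical option above ~$10^5$
points, where KNN already takes seconds per call.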
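
For intuition about what the fast variants approximate, the sketch below implements a dense bilateral filter in NumPy that weights every pixel pair by a Gaussian in guide space. It is an $O(N^2)$ reference for tiny images only, not how `BilateralFilter` is implemented. The function name `bilateral_bruteforce`, the $\exp(-\tfrac{1}{2} d^2)$ kernel convention, and the synthetic step-edge image are all assumptions made for illustration.

```python
import numpy as np


def bilateral_bruteforce(img, sigma_xy, sigma_rgb):
    """Dense O(N^2) bilateral filter; the noisy image is its own guide."""
    h, w, c = img.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    xy = np.stack([xs, ys], axis=-1).reshape(-1, 2).astype(np.float64)
    val = img.reshape(-1, c).astype(np.float64)
    guide = np.concatenate([xy / sigma_xy, val / sigma_rgb], axis=1)
    # Pairwise squared distances in the 5-D guide space, shape (N, N).
    sq = np.sum(guide**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (guide @ guide.T), 0.0)
    weights = np.exp(-0.5 * d2)
    # Normalized weighted average of the values being filtered.
    out = (weights @ val) / weights.sum(axis=1, keepdims=True)
    return out.reshape(h, w, c)


# Synthetic 16x16 step edge (dark left half, bright right half) plus noise.
rng = np.random.default_rng(0)
clean = np.zeros((16, 16, 3))
clean[:, :8] = 0.2
clean[:, 8:] = 0.8
noisy = clean + rng.normal(0.0, 0.05, clean.shape)

out = bilateral_bruteforce(noisy, sigma_xy=4.0, sigma_rgb=0.1)
mse_noisy = np.mean((noisy - clean) ** 2)
mse_out = np.mean((out - clean) ** 2)
print(mse_out < mse_noisy)  # True
```

Pixels on opposite sides of the edge differ by about 0.6 per channel, or six standard deviations at $\sigma_{rgb} = 0.1$, so their mutual weights are negligible and the edge survives while noise within each half is averaged away.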

## Constraints

- `PackedHashTable128` supports $D \le 7$ axes per key. Lattice
Five binary image files added under docs/user_guide/img/ (126 KB, 128 KB, 152 KB, 68.4 KB, 119 KB); contents not shown.
examples/bilateral_image_example.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Bilateral filtering on a 2D image: denoising the NASA `astronaut` test image.

Demonstrates the three bilateral families shipped in `warpconvnet.nn`:

  - BilateralFilter (KNN / radius)
  - BilateralFilterGrid (sparse d-cube lattice)
  - BilateralPermutohedralFilter (permutohedral lattice)

We treat each pixel as a point with 2D position (x, y) and 3D color (r, g, b).
The bilateral guide is concat(xy/sigma_xy, rgb/sigma_rgb); the value being
filtered is the noisy color. The astronaut image is in the public domain
(NASA), shipped with scikit-image.

Saves five JPEGs to <out-dir>:
    astronaut_original.jpg, astronaut_noisy.jpg, astronaut_knn.jpg,
    astronaut_grid.jpg, astronaut_permutohedral.jpg

Run:
    python examples/bilateral_image_example.py --out-dir docs/user_guide/img
"""

import argparse
import time

import numpy as np
import torch
from skimage import data, util


def _to_pixel_pointcloud(img: np.ndarray, device: torch.device):
    """Flatten an (H, W, 3) image into (N, 2) xy + (N, 3) rgb tensors."""
    h, w, _ = img.shape
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    xy = np.stack([xs, ys], axis=-1).reshape(-1, 2).astype(np.float32)
    rgb = img.reshape(-1, 3).astype(np.float32)
    return (
        torch.from_numpy(xy).to(device),
        torch.from_numpy(rgb).to(device),
    )


def _from_pixel_pointcloud(values: torch.Tensor, h: int, w: int) -> np.ndarray:
    return values.detach().cpu().numpy().reshape(h, w, 3).clip(0, 1)


def _save_image(path: str, arr: np.ndarray) -> None:
    from PIL import Image

    arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    img = Image.fromarray(arr)
    if path.lower().endswith((".jpg", ".jpeg")):
        img.save(path, quality=92, optimize=True, progressive=True)
    else:
        img.save(path, optimize=True)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", default="docs/user_guide/img")
    parser.add_argument("--noise-var", type=float, default=0.01)
    parser.add_argument("--sigma-xy", type=float, default=4.0)
    parser.add_argument("--sigma-rgb", type=float, default=0.1)
    parser.add_argument("--knn-k", type=int, default=24)
    args = parser.parse_args()
    import os

    os.makedirs(args.out_dir, exist_ok=True)

    if not torch.cuda.is_available():
        raise SystemExit("CUDA required for bilateral filters")
    device = torch.device("cuda")

    import warpconvnet.nn as wn

    # ---- input image -------------------------------------------------------
    img = util.img_as_float(data.astronaut())  # (512, 512, 3) in [0, 1]
    noisy = util.random_noise(img, mode="gaussian", var=args.noise_var)
    h, w, _ = img.shape

    xy, rgb_clean = _to_pixel_pointcloud(img, device)
    _, rgb_noisy = _to_pixel_pointcloud(noisy.astype(np.float32), device)

    # ---- KNN bilateral -----------------------------------------------------
    knn_filter = wn.BilateralFilter(
        sigma_xyz=args.sigma_xy,
        sigma_feat=args.sigma_rgb,
        k=args.knn_k,
        mode="knn",
    )
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out_knn = knn_filter(xy, rgb_noisy, rgb_noisy)
    torch.cuda.synchronize()
    t_knn = time.perf_counter() - t0

    # ---- sparse d-cube grid -----------------------------------------------
    grid_filter = wn.BilateralFilterGrid(
        sigma_xyz=args.sigma_xy,
        sigma_feat=args.sigma_rgb,
    )
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out_grid = grid_filter(xy, rgb_noisy, rgb_noisy)
    torch.cuda.synchronize()
    t_grid = time.perf_counter() - t0

    # ---- permutohedral lattice --------------------------------------------
    perm_filter = wn.BilateralPermutohedralFilter(
        sigma_xyz=args.sigma_xy,
        sigma_feat=args.sigma_rgb,
    )
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    out_perm = perm_filter(xy, rgb_noisy, rgb_noisy)
    torch.cuda.synchronize()
    t_perm = time.perf_counter() - t0

    # ---- save individual JPEGs --------------------------------------------
    knn_img = _from_pixel_pointcloud(out_knn, h, w)
    grid_img = _from_pixel_pointcloud(out_grid, h, w)
    perm_img = _from_pixel_pointcloud(out_perm, h, w)
    outputs = {
        "astronaut_original.jpg": img,
        "astronaut_noisy.jpg": noisy,
        "astronaut_knn.jpg": knn_img,
        "astronaut_grid.jpg": grid_img,
        "astronaut_permutohedral.jpg": perm_img,
    }
    for name, arr in outputs.items():
        path = os.path.join(args.out_dir, name)
        _save_image(path, arr)
        print(f"Saved {path}")

    # ---- PSNR vs original (data_range=1.0 since img is float in [0, 1]) ---
    from skimage.metrics import peak_signal_noise_ratio as psnr

    ref = img.astype(np.float32)
    psnr_noisy = psnr(ref, np.clip(noisy, 0, 1).astype(np.float32), data_range=1.0)
    psnr_knn = psnr(ref, np.clip(knn_img, 0, 1).astype(np.float32), data_range=1.0)
    psnr_grid = psnr(ref, np.clip(grid_img, 0, 1).astype(np.float32), data_range=1.0)
    psnr_perm = psnr(ref, np.clip(perm_img, 0, 1).astype(np.float32), data_range=1.0)

    print()
    print(f" {'Stage':<22}{'Time':>10} PSNR (dB)")
    print(f" {'-' * 46}")
    print(f" {'Noisy input':<22}{'-':>10} {psnr_noisy:6.2f}")
    print(f" {'KNN (k=' + str(args.knn_k) + ')':<22}{t_knn*1e3:>8.1f} ms {psnr_knn:6.2f}")
    print(f" {'Grid':<22}{t_grid*1e3:>8.1f} ms {psnr_grid:6.2f}")
    print(f" {'Permutohedral':<22}{t_perm*1e3:>8.1f} ms {psnr_perm:6.2f}")


if __name__ == "__main__":
    main()
