Skip to content

Commit 677cfef

Browse files
Merge pull request #6 from beauagainagainagainagainagain/codex/add-wavelet-based-image-denoising
2 parents 1b8189b + e480c68 commit 677cfef

4 files changed

Lines changed: 133 additions & 0 deletions

File tree

DIRECTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@
336336
* [Rotation](digital_image_processing/rotation/rotation.py)
337337
* [Sepia](digital_image_processing/sepia.py)
338338
* [Test Digital Image Processing](digital_image_processing/test_digital_image_processing.py)
339+
* [Wavelet Denoising](digital_image_processing/wavelet_denoising.py)
339340

340341
## Divide And Conquer
341342
* [Closest Pair Of Points](divide_and_conquer/closest_pair_of_points.py)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Wavelet-based image denoising example."""
2+
3+
from __future__ import annotations
4+
5+
import math
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import pywt
10+
from skimage import data
11+
from skimage.metrics import peak_signal_noise_ratio
12+
13+
14+
def im2double(image: np.ndarray) -> np.ndarray:
15+
"""Return the image converted to ``float64`` precision.
16+
17+
If the input image already contains floating point values the image is
18+
returned cast to ``float64`` without further scaling. Integer images are
19+
scaled to the ``[0, 1]`` range, matching MATLAB's :func:`im2double`
20+
behaviour.
21+
"""
22+
23+
if np.issubdtype(image.dtype, np.floating):
24+
return image.astype(np.float64)
25+
26+
info = np.iinfo(image.dtype)
27+
return image.astype(np.float64) / info.max
28+
29+
30+
def normalize_img(image: np.ndarray) -> np.ndarray:
31+
"""Clip image data to the ``[0, 1]`` range."""
32+
33+
return np.clip(image, 0.0, 1.0)
34+
35+
36+
def denoise_image_wavelet(
37+
original_img: np.ndarray,
38+
noise_level: float = 0.1,
39+
wavelet_name: str = "db4",
40+
decomposition_level: int = 3,
41+
rng: np.random.Generator | None = None,
42+
) -> tuple[np.ndarray, np.ndarray]:
43+
"""Denoise an image using wavelet thresholding.
44+
45+
The function adds synthetic Gaussian noise to ``original_img`` and then
46+
performs wavelet thresholding to suppress the noise.
47+
48+
Args:
49+
original_img: Clean input image in the ``[0, 1]`` range.
50+
noise_level: Standard deviation of the synthetic Gaussian noise.
51+
wavelet_name: Name of the wavelet family to use.
52+
decomposition_level: Number of wavelet decomposition levels.
53+
rng: Optional ``numpy`` random number generator for reproducibility.
54+
55+
Returns:
56+
A tuple ``(noisy_img, denoised_img)`` containing the noisy and
57+
denoised images respectively.
58+
"""
59+
60+
if rng is None:
61+
rng = np.random.default_rng()
62+
63+
original_img = im2double(original_img)
64+
65+
noisy_img = original_img + noise_level * rng.standard_normal(original_img.shape)
66+
noisy_img = normalize_img(noisy_img)
67+
68+
coeffs = pywt.wavedec2(noisy_img, wavelet_name, level=decomposition_level)
69+
coeffs_approx = coeffs[0]
70+
coeffs_details = coeffs[1:]
71+
72+
detail_coeffs = [
73+
detail_array.ravel()
74+
for level_details in coeffs_details
75+
for detail_array in level_details
76+
]
77+
if not detail_coeffs:
78+
msg = "Wavelet decomposition did not produce detail coefficients."
79+
raise ValueError(msg)
80+
all_detail_coeffs = np.concatenate(detail_coeffs)
81+
82+
sigma = np.median(np.abs(all_detail_coeffs)) / 0.6745
83+
threshold = sigma * math.sqrt(2.0 * math.log(original_img.size))
84+
85+
denoised_details = [
86+
tuple(pywt.threshold(detail_array, threshold, mode="soft") for detail_array in level_details)
87+
for level_details in coeffs_details
88+
]
89+
coeffs_denoised = [coeffs_approx, *denoised_details]
90+
91+
denoised_img = pywt.waverec2(coeffs_denoised, wavelet_name)
92+
denoised_img = denoised_img[: original_img.shape[0], : original_img.shape[1]]
93+
denoised_img = normalize_img(denoised_img)
94+
95+
return noisy_img, denoised_img
96+
97+
98+
def main() -> None:
99+
"""Run the wavelet denoising example and display the results."""
100+
101+
original_img = im2double(data.camera())
102+
noisy_img, denoised_img = denoise_image_wavelet(original_img)
103+
104+
psnr_noisy = peak_signal_noise_ratio(original_img, noisy_img)
105+
psnr_denoised = peak_signal_noise_ratio(original_img, denoised_img)
106+
107+
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
108+
ax = axes[0]
109+
ax.imshow(original_img, cmap="gray")
110+
ax.set_title("Original Image")
111+
ax.axis("off")
112+
113+
ax = axes[1]
114+
ax.imshow(noisy_img, cmap="gray")
115+
ax.set_title(f"Noisy Image (PSNR: {psnr_noisy:.2f} dB)")
116+
ax.axis("off")
117+
118+
ax = axes[2]
119+
ax.imshow(denoised_img, cmap="gray")
120+
ax.set_title(f"Denoised Image (PSNR: {psnr_denoised:.2f} dB)")
121+
ax.axis("off")
122+
123+
plt.tight_layout()
124+
plt.show()
125+
126+
127+
if __name__ == "__main__":
128+
main()

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ dependencies = [
1616
"lxml>=5.3",
1717
"matplotlib>=3.9.3",
1818
"numpy>=2.1.3",
19+
"PyWavelets>=1.6.0",
1920
"opencv-python>=4.10.0.84",
2021
"pandas>=2.2.3",
2122
"pillow>=11",
2223
"requests>=2.32.3",
2324
"rich>=13.9.4",
25+
"scikit-image>=0.24.0",
2426
"scikit-learn>=1.5.2",
2527
"sphinx-pyproject>=0.3",
2628
"statsmodels>=0.14.4",

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ keras
66
lxml
77
matplotlib
88
numpy
9+
PyWavelets
910
opencv-python
1011
pandas
1112
pdfkit
1213
pillow
1314
requests
1415
rich
16+
scikit-image
1517
scikit-learn
1618
sphinx_pyproject
1719
statsmodels

0 commit comments

Comments
 (0)