Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DIRECTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@
* [Rotation](digital_image_processing/rotation/rotation.py)
* [Sepia](digital_image_processing/sepia.py)
* [Test Digital Image Processing](digital_image_processing/test_digital_image_processing.py)
* [Wavelet Denoising](digital_image_processing/wavelet_denoising.py)

## Divide And Conquer
* [Closest Pair Of Points](divide_and_conquer/closest_pair_of_points.py)
Expand Down
128 changes: 128 additions & 0 deletions digital_image_processing/wavelet_denoising.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Wavelet-based image denoising example."""

from __future__ import annotations

import math

import matplotlib.pyplot as plt
import numpy as np
import pywt
from skimage import data
from skimage.metrics import peak_signal_noise_ratio


def im2double(image: np.ndarray) -> np.ndarray:
"""Return the image converted to ``float64`` precision.

If the input image already contains floating point values the image is
returned cast to ``float64`` without further scaling. Integer images are
scaled to the ``[0, 1]`` range, matching MATLAB's :func:`im2double`
behaviour.
"""

if np.issubdtype(image.dtype, np.floating):
return image.astype(np.float64)

info = np.iinfo(image.dtype)
return image.astype(np.float64) / info.max


def normalize_img(image: np.ndarray) -> np.ndarray:
"""Clip image data to the ``[0, 1]`` range."""

return np.clip(image, 0.0, 1.0)


def denoise_image_wavelet(
original_img: np.ndarray,
noise_level: float = 0.1,
wavelet_name: str = "db4",
decomposition_level: int = 3,
rng: np.random.Generator | None = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Denoise an image using wavelet thresholding.

The function adds synthetic Gaussian noise to ``original_img`` and then
performs wavelet thresholding to suppress the noise.

Args:
original_img: Clean input image in the ``[0, 1]`` range.
noise_level: Standard deviation of the synthetic Gaussian noise.
wavelet_name: Name of the wavelet family to use.
decomposition_level: Number of wavelet decomposition levels.
rng: Optional ``numpy`` random number generator for reproducibility.

Returns:
A tuple ``(noisy_img, denoised_img)`` containing the noisy and
denoised images respectively.
"""

if rng is None:
rng = np.random.default_rng()

original_img = im2double(original_img)

noisy_img = original_img + noise_level * rng.standard_normal(original_img.shape)
noisy_img = normalize_img(noisy_img)

coeffs = pywt.wavedec2(noisy_img, wavelet_name, level=decomposition_level)
coeffs_approx = coeffs[0]
coeffs_details = coeffs[1:]

detail_coeffs = [
detail_array.ravel()
for level_details in coeffs_details
for detail_array in level_details
]
if not detail_coeffs:
msg = "Wavelet decomposition did not produce detail coefficients."
raise ValueError(msg)
all_detail_coeffs = np.concatenate(detail_coeffs)

sigma = np.median(np.abs(all_detail_coeffs)) / 0.6745
threshold = sigma * math.sqrt(2.0 * math.log(original_img.size))

denoised_details = [
tuple(pywt.threshold(detail_array, threshold, mode="soft") for detail_array in level_details)
for level_details in coeffs_details
]
coeffs_denoised = [coeffs_approx, *denoised_details]

denoised_img = pywt.waverec2(coeffs_denoised, wavelet_name)
denoised_img = denoised_img[: original_img.shape[0], : original_img.shape[1]]
denoised_img = normalize_img(denoised_img)

return noisy_img, denoised_img


def main() -> None:
"""Run the wavelet denoising example and display the results."""

original_img = im2double(data.camera())
noisy_img, denoised_img = denoise_image_wavelet(original_img)

psnr_noisy = peak_signal_noise_ratio(original_img, noisy_img)
psnr_denoised = peak_signal_noise_ratio(original_img, denoised_img)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
ax = axes[0]
ax.imshow(original_img, cmap="gray")
ax.set_title("Original Image")
ax.axis("off")

ax = axes[1]
ax.imshow(noisy_img, cmap="gray")
ax.set_title(f"Noisy Image (PSNR: {psnr_noisy:.2f} dB)")
ax.axis("off")

ax = axes[2]
ax.imshow(denoised_img, cmap="gray")
ax.set_title(f"Denoised Image (PSNR: {psnr_denoised:.2f} dB)")
ax.axis("off")

plt.tight_layout()
plt.show()


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ dependencies = [
"lxml>=5.3",
"matplotlib>=3.9.3",
"numpy>=2.1.3",
"PyWavelets>=1.6.0",
"opencv-python>=4.10.0.84",
"pandas>=2.2.3",
"pillow>=11",
"requests>=2.32.3",
"rich>=13.9.4",
"scikit-image>=0.24.0",
"scikit-learn>=1.5.2",
"sphinx-pyproject>=0.3",
"statsmodels>=0.14.4",
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ keras
lxml
matplotlib
numpy
PyWavelets
opencv-python
pandas
pdfkit
pillow
requests
rich
scikit-image
scikit-learn
sphinx_pyproject
statsmodels
Expand Down