Skip to content

Commit 0086c1d

Browse files
committed
Add waveletdiff: wavelet-based denoising differentiator
- Implements Donoho & Johnstone (1994) universal threshold estimator - Vectorised thresholding over multi-column inputs - Supports variable step size via np.gradient - Supports arbitrary axis via ascontiguousarray + moveaxis - Adds pywt dependency
1 parent 654bd76 commit 0086c1d

3 files changed

Lines changed: 138 additions & 5 deletions

File tree

pynumdiff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
from .finite_difference import finitediff, first_order, second_order, fourth_order
1616
from .smooth_finite_difference import kerneldiff, meandiff, mediandiff, gaussiandiff, friedrichsdiff, butterdiff
1717
from .polynomial_fit import splinediff, polydiff, savgoldiff
18-
from .basis_fit import spectraldiff, rbfdiff
18+
from .basis_fit import spectraldiff, rbfdiff, waveletdiff
1919
from .total_variation_regularization import iterative_velocity
2020
from .kalman_smooth import kalman_filter, rts_smooth, rtsdiff, constant_velocity, constant_acceleration, constant_jerk

pynumdiff/basis_fit.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from warnings import warn
33
import numpy as np
44
from scipy import sparse
5+
import pywt
56

67
from pynumdiff.utils import utility
78

@@ -137,3 +138,125 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
137138
alpha = sparse.linalg.spsolve(rbf_regularized, x) # solve sparse system targeting the noisy data, O(N sigma^2)
138139

139140
return rbf @ alpha, drbfdt @ alpha # find samples of reconstructions using the smooth bases
141+
142+
143+
def waveletdiff(x, dt_or_t, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
144+
"""Smooth and differentiate noisy data via discrete wavelet denoising.
145+
146+
Decomposes x into wavelet detail and approximation coefficients, soft-thresholds
147+
the detail coefficients to remove noise using the Donoho & Johnstone (1994)
148+
universal threshold estimator, reconstructs a smoothed signal, then
149+
differentiates with finite differences via np.gradient.
150+
151+
:param np.array x: data to differentiate
152+
:param float or array dt_or_t: scalar dt or array of sample times. If an
153+
array is provided it is passed directly to np.gradient, giving correct
154+
results for non-uniformly sampled data.
155+
:param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
156+
'db4' is a solid general-purpose default. Biorthogonal wavelets such as
157+
'bior2.2' or 'bior4.4' are symmetric and designed for smooth reconstruction
158+
but may need a lower threshold value.
159+
:param int level: decomposition depth. None (default) resolves to
160+
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short
161+
signals. Increase for heavily oversampled data.
162+
:param float threshold: soft-thresholding scale factor in [0, inf).
163+
Multiplies the universal threshold sigma * sqrt(2 * log(N)).
164+
threshold=1.0 is the classical Donoho & Johnstone universal threshold
165+
and is the recommended starting point. Values < 1.0 give less smoothing;
166+
values > 1.0 give more aggressive smoothing. This parameter maps onto
167+
tvgamma in the pynumdiff.optimize framework.
168+
:param int axis: axis along which to differentiate (default 0).
169+
:param str mode: PyWavelets signal extension mode passed to wavedec/waverec.
170+
'periodization' (default) keeps coefficient arrays exactly length N and
171+
is the most numerically stable choice for differentiation. 'reflect' is
172+
a good alternative for clearly non-periodic signals.
173+
See pywt.Modes.modes for all options.
174+
:return: - **x_hat** (np.array) -- estimated (smoothed) x
175+
- **dxdt_hat** (np.array) -- estimated derivative of x
176+
"""
177+
N = x.shape[axis]
178+
179+
# Axis normalisation — bring target axis to front.
180+
# Skip moveaxis when axis is already 0 to avoid an unnecessary allocation.
181+
# When we do move, call ascontiguousarray immediately so the subsequent
182+
# reshape is guaranteed zero-copy.
183+
if axis == 0:
184+
x_work = x if x.flags['C_CONTIGUOUS'] else np.ascontiguousarray(x)
185+
else:
186+
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))
187+
188+
shape = x_work.shape
189+
x_flat = x_work.reshape(N, -1) # (N, M) contiguous, no hidden copy
190+
M = x_flat.shape[1]
191+
192+
if np.isscalar(dt_or_t):
193+
grad_arg = dt_or_t
194+
else:
195+
if len(dt_or_t) != N:
196+
raise ValueError(
197+
"`dt_or_t` array must have the same length as x along `axis`."
198+
)
199+
grad_arg = dt_or_t # np.gradient accepts a full coordinate array
200+
201+
# Conservative level default avoids over-decomposing short signals
202+
# (pywt default uses the maximum possible level).
203+
if level is None:
204+
max_level = pywt.dwt_max_level(N, wavelet)
205+
level = min(max_level, 5)
206+
207+
# Decompose all columns and stack coefficients into 2-D arrays of shape
208+
# (coeff_len_i, M). Probing column 0 first lets us pre-allocate correctly;
209+
# the probe result is reused for col 0 so we pay N+1 wavedec calls total.
210+
_probe = pywt.wavedec(x_flat[:, 0], wavelet, level=level, mode=mode)
211+
coeff_lengths = [len(c) for c in _probe]
212+
n_levels = len(_probe)
213+
214+
coeffs_all = [
215+
np.empty((coeff_lengths[i], M), dtype=x_flat.dtype)
216+
for i in range(n_levels)
217+
]
218+
for i, c in enumerate(_probe):
219+
coeffs_all[i][:, 0] = c
220+
221+
for col in range(1, M):
222+
for i, c in enumerate(
223+
pywt.wavedec(x_flat[:, col], wavelet, level=level, mode=mode)
224+
):
225+
coeffs_all[i][:, col] = c
226+
227+
# Vectorised noise estimation and soft-thresholding over all columns at once.
228+
# sigma: robust MAD estimator from finest detail level, shape (M,).
229+
# thresh: per-column universal threshold, shape (M,).
230+
# Approximation coefficients (index 0) are left untouched; only detail
231+
# levels (indices 1..n_levels-1) are thresholded.
232+
sigma = np.median(np.abs(coeffs_all[-1]), axis=0) / 0.6745
233+
np.maximum(sigma, 1e-10, out=sigma) # floor avoids zero threshold on clean signals
234+
235+
thresh = threshold * sigma * np.sqrt(2 * np.log(N)) # shape (M,)
236+
237+
coeffs_denoised = [coeffs_all[0]] + [
238+
pywt.threshold(c, thresh[np.newaxis, :], mode='soft')
239+
for c in coeffs_all[1:]
240+
]
241+
242+
# Reconstruct and differentiate — pywt.waverec is 1-D only so a column
243+
# loop remains, but all Python-level arithmetic has been moved out above.
244+
x_hat_flat = np.empty_like(x_flat)
245+
dxdt_hat_flat = np.empty_like(x_flat)
246+
247+
for col in range(M):
248+
col_coeffs = [coeffs_denoised[i][:, col] for i in range(n_levels)]
249+
x_hat_col = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
250+
x_hat_flat[:, col] = x_hat_col
251+
dxdt_hat_flat[:, col] = np.gradient(x_hat_col, grad_arg)
252+
253+
# Restore original shape and axis order.
254+
# moveaxis on the way out is only needed when we moved on the way in.
255+
x_hat = x_hat_flat.reshape(shape)
256+
dxdt_hat = dxdt_hat_flat.reshape(shape)
257+
258+
if axis != 0:
259+
x_hat = np.moveaxis(x_hat, 0, axis)
260+
dxdt_hat = np.moveaxis(dxdt_hat, 0, axis)
261+
262+
return x_hat, dxdt_hat

pynumdiff/tests/test_diff_methods.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..smooth_finite_difference import kerneldiff, mediandiff, meandiff, gaussiandiff, friedrichsdiff, butterdiff
66
from ..finite_difference import finitediff, first_order, second_order, fourth_order
77
from ..polynomial_fit import polydiff, savgoldiff, splinediff
8-
from ..basis_fit import spectraldiff, rbfdiff
8+
from ..basis_fit import spectraldiff, rbfdiff, waveletdiff
99
from ..total_variation_regularization import velocity, acceleration, jerk, iterative_velocity, smooth_acceleration
1010
from ..kalman_smooth import rtsdiff, constant_velocity, constant_acceleration, constant_jerk, robustdiff
1111
from ..linear_model import lineardiff
@@ -47,6 +47,8 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
4747
(spline_irreg_step, {'degree':5, 's':2}),
4848
(spectraldiff, {'high_freq_cutoff':0.2}), (spectraldiff, [0.2]),
4949
(rbfdiff, {'sigma':0.5, 'lmbd':0.001}),
50+
(waveletdiff, {'wavelet':'db4', 'threshold':1.0}),
51+
(waveletdiff, {'wavelet':'db4', 'threshold':1.0}),
5052
(constant_velocity, {'r':1e-2, 'q':1e3}), (constant_velocity, [1e-2, 1e3]),
5153
(constant_acceleration, {'r':1e-3, 'q':1e4}), (constant_acceleration, [1e-3, 1e4]),
5254
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
@@ -162,6 +164,12 @@ def spline_irreg_step(*args, **kwargs): return splinediff(*args, **kwargs)
162164
[(-2, -2), (0, 0), (0, -1), (0, 0)],
163165
[(0, 0), (2, 2), (0, 0), (2, 2)],
164166
[(1, 1), (3, 3), (1, 1), (3, 3)]],
167+
waveletdiff: [[(-14, -15), (-14, -14), (-1, -1), (0, 0)],
168+
[(-9, -9), (-8, -8), (0, 0), (1, 1)],
169+
[(-9, -9), (0, 0), (0, 0), (1, 1)],
170+
[(-1, -1), (0, 0), (0, 0), (1, 1)],
171+
[(1, 0), (2, 2), (1, 1), (2, 2)],
172+
[(0, 0), (3, 3), (1, 0), (3, 3)]],
165173
velocity: [[(-25, -25), (-18, -19), (0, -1), (1, 0)],
166174
[(-12, -12), (-11, -12), (-1, -1), (-1, -2)],
167175
[(0, -1), (1, 0), (0, -1), (1, 0)],
@@ -308,7 +316,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
308316
(kerneldiff, {'kernel': 'gaussian', 'window_size': 5}),
309317
(butterdiff, {'filter_order': 3, 'cutoff_freq': 1 - 1e-6}),
310318
(finitediff, {}),
311-
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3})
319+
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
320+
(waveletdiff, {'wavelet': 'db4', 'threshold': 1.0}),
312321
]
313322

314323
# Similar to the error_bounds table, index by method first. But then we test against only one 2D function,
@@ -319,7 +328,8 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
319328
kerneldiff: [(2, 1), (3, 2)],
320329
butterdiff: [(0, -1), (1, -1)],
321330
finitediff: [(0, -1), (1, -1)],
322-
savgoldiff: [(0, -1), (1, 1)]
331+
savgoldiff: [(0, -1), (1, 1)],
332+
waveletdiff: [(1, 0), (2, 1)],
323333
}
324334

325335
@mark.parametrize("multidim_method_and_params", multidim_methods_and_params)
@@ -372,4 +382,4 @@ def test_multidimensionality(multidim_method_and_params, request):
372382
ax2.plot_wireframe(T1, T2, computed_d2)
373383
ax3.plot_wireframe(T1, T2, computed_laplacian, label='computed')
374384
legend = ax3.legend(bbox_to_anchor=(0.7, 0.8)); legend.legend_handles[0].set_facecolor(pyplot.cm.viridis(0.6))
375-
fig.suptitle(f'{diff_method.__name__}', fontsize=16)
385+
fig.suptitle(f'{diff_method.__name__}', fontsize=16)

0 commit comments

Comments
 (0)