Skip to content

Commit dbf2f3f

Browse files
committed
Replace FFT-based wavelet differentiation with wavelet basis derivative operator
1 parent 5bcbe29 commit dbf2f3f

1 file changed

Lines changed: 111 additions & 70 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 111 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Methods based on fitting basis functions to data"""
2+
from functools import lru_cache
23
from warnings import warn
34
import numpy as np
45
from scipy import sparse
6+
from scipy.interpolate import CubicSpline
57
import pywt
68

79
from pynumdiff.utils import utility
@@ -136,14 +138,94 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
136138
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)
137139

138140

141+
@lru_cache(maxsize=32)
def _wavelet_derivative_synthesis_matrix(N, dt, wavelet, level, mode):
    """Tabulate d/dt of every inverse-DWT synthesis basis vector as a sparse matrix.

    With the wavelet, level, mode and signal length held fixed, waverec is
    linear in its coefficients, so the reconstruction can be written as

        x(t_n) = sum_k c_k phi_k(t_n),

    where phi_k is the reconstruction of the k-th unit coefficient vector.
    This routine reconstructs each phi_k, fits a cubic spline through its
    (padded) non-zero samples, and stores the spline's first derivative at
    those same samples. The caller can then form

        x'(t_n) = sum_k c_k phi'_k(t_n)

    as a single sparse matrix product with the flattened coefficients.
    Results are memoised by lru_cache, so all arguments must be hashable.

    :param int N: number of samples in the signal.
    :param float dt: uniform spacing between samples.
    :param wavelet: wavelet specification forwarded to pywt.wavedec/waverec.
    :param int level: decomposition level forwarded to pywt.wavedec.
    :param str mode: signal extension mode forwarded to pywt.wavedec/waverec.
    :return: - **Dphi** (scipy.sparse.csr_matrix) -- shape (N, total number of
               coefficients); column k holds the sampled derivative of phi_k
             - **coeff_lengths** (tuple of int) -- length of each coefficient
               band, in the order wavedec returns them
    """
    # Decompose a zero signal purely to learn the coefficient layout.
    band_sizes = tuple(
        len(c) for c in pywt.wavedec(np.zeros(N), wavelet, level=level, mode=mode)
    )
    band_starts = np.cumsum((0,) + band_sizes[:-1])
    times = np.arange(N, dtype=float) * dt
    tol = 1e-12  # magnitudes at or below this are treated as zero

    row_idx, col_idx, entries = [], [], []

    for band, (start, size) in enumerate(zip(band_starts, band_sizes)):
        for pos in range(size):
            # Reconstructing a unit coefficient vector yields one basis vector.
            unit = [np.zeros(n) for n in band_sizes]
            unit[band][pos] = 1.0
            phi = pywt.waverec(unit, wavelet, mode=mode)[:N]

            nonzero = np.flatnonzero(np.abs(phi) > tol)
            if nonzero.size == 0:
                continue

            # Pad the non-zero samples by one neighbour on each side (clipped
            # to the signal) so the spline has context at the support edges.
            padded = np.concatenate((nonzero - 1, nonzero, nonzero + 1))
            nodes = np.unique(np.clip(padded, 0, N - 1))

            # Too few nodes, or nodes spread over a span more than twice
            # their count (e.g. support split across both ends): use every
            # sample instead.
            span = nodes[-1] - nodes[0] + 1
            if nodes.size < 4 or span > 2 * nodes.size:
                nodes = np.arange(N)

            # Not-a-knot ends when there are at least four nodes; otherwise
            # clamp both end slopes to zero.
            if nodes.size >= 4:
                boundary = 'not-a-knot'
            else:
                boundary = ((1, 0.0), (1, 0.0))
            fit = CubicSpline(times[nodes], phi[nodes], bc_type=boundary, extrapolate=False)
            slopes = fit(times[nodes], 1)

            # Store only finite, non-negligible derivative samples.
            keep = np.isfinite(slopes) & (np.abs(slopes) > tol)
            row_idx.extend(nodes[keep])
            col_idx.extend(np.full(np.count_nonzero(keep), start + pos))
            entries.extend(slopes[keep])

    Dphi = sparse.csr_matrix((entries, (row_idx, col_idx)), shape=(N, sum(band_sizes)))
    return Dphi, band_sizes
207+
208+
209+
def _flatten_wavelet_coeffs(coeffs):
    """Stack the rows of every coefficient band into one 2-D matrix.

    Each band is iterated along its first axis and the resulting rows are
    stacked vertically in band order.

    :param list coeffs: coefficient bands, e.g. as produced by wavedec.
    :return: (np.array) -- all band rows stacked vertically with np.vstack
    """
    stacked_rows = []
    for band in coeffs:
        # Iterating an array walks its first axis, one row per coefficient.
        stacked_rows.extend(band)
    return np.vstack(stacked_rows)
212+
213+
139214
def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
140-
"""Smooth and differentiate noisy data via discrete wavelet denoising.
215+
"""Smooth and differentiate noisy data with a wavelet-basis derivative sum.
141216
142-
Decomposes x into wavelet detail and approximation coefficients, soft-thresholds
143-
the detail coefficients to remove noise using the Donoho & Johnstone (1994)
144-
universal threshold estimator, reconstructs a smoothed signal, then
145-
differentiates analytically by applying derivative reconstruction filters to
146-
the denoised wavelet coefficients.
217+
Decomposes x into wavelet approximation/detail coefficients, soft-thresholds
218+
the detail coefficients to denoise, reconstructs a smoothed signal, and then
219+
estimates the derivative directly from the denoised wavelet coefficients:
220+
221+
x(t_n) = sum_k c_k phi_k(t_n)
222+
x'(t_n) = sum_k c_k phi'_k(t_n)
223+
224+
The first sum is the ordinary inverse wavelet transform. The second sum is
225+
evaluated by precomputing sparse samples of the derivative of each synthesis
226+
basis function and multiplying that sparse matrix by the denoised
227+
coefficients. This avoids the previous reconstruct-then-FFT derivative path
228+
and does not call finite differences or np.gradient on the signal.
147229
148230
Because the DWT requires uniform spacing, this method only accepts a scalar
149231
time step dt (not a vector of sample times). For non-uniformly sampled data,
@@ -152,24 +234,13 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
152234
:param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
153235
:param float dt: uniform time step between samples.
154236
:param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
155-
'db4' is a solid general-purpose default. Biorthogonal wavelets such as
156-
'bior2.2' or 'bior4.4' are symmetric and designed for smooth reconstruction
157-
but may need a lower threshold value.
158237
:param int level: decomposition depth. None (default) resolves to
159-
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short
160-
signals. Increase for heavily oversampled data.
238+
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
161239
:param float threshold: soft-thresholding scale factor in [0, inf).
162-
Multiplies the universal threshold sigma * sqrt(2 * log(N)).
163-
threshold=1.0 is the classical Donoho & Johnstone universal threshold
164-
and is the recommended starting point. Values < 1.0 give less smoothing;
165-
values > 1.0 give more aggressive smoothing. This parameter maps onto
166-
tvgamma in the pynumdiff.optimize framework.
167240
:param int axis: axis along which to differentiate (default 0).
168241
:param str mode: PyWavelets signal extension mode passed to wavedec/waverec.
169-
'periodization' (default) keeps coefficient arrays exactly length N and
170-
is the most numerically stable choice for differentiation. 'reflect' is
171-
a good alternative for clearly non-periodic signals.
172-
See pywt.Modes.modes for all options.
242+
'periodization' keeps coefficient arrays compact; 'reflect' is often a
243+
better choice for clearly non-periodic signals.
173244
:return: - **x_hat** (np.array) -- estimated (smoothed) x
174245
- **dxdt_hat** (np.array) -- estimated derivative of x
175246
"""
@@ -180,19 +251,11 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
180251
)
181252

182253
N = x.shape[axis]
183-
184-
# Bring axis of differentiation to front so each column of x_flat is one
185-
# signal to differentiate. moveaxis returns a view with updated strides,
186-
# so ascontiguousarray ensures the subsequent reshape is zero-copy.
187254
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))
188255
shape = x_work.shape
189-
x_flat = x_work.reshape(N, -1) # (N, M)
256+
x_flat = x_work.reshape(N, -1)
190257
M = x_flat.shape[1]
191258

192-
# Conservative level cap: pywt's default uses the maximum possible level,
193-
# which can over-decompose short signals and wash out meaningful detail.
194-
# Capping at 5 keeps at least 2^5 = 32 samples in the coarsest subband,
195-
# which is enough to represent a smooth approximation without artefacts.
196259
if level is None:
197260
max_level = pywt.dwt_max_level(N, wavelet)
198261
level = min(max_level, 5)
@@ -216,57 +279,35 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
216279
):
217280
coeffs_all[i][:, col] = c
218281

219-
# Vectorised noise estimation and soft-thresholding over all columns at once.
220-
#
221-
# Soft-thresholding achieves smoothing by shrinking wavelet detail
222-
# coefficients toward zero: coefficients whose magnitude is below the
223-
# threshold (mostly noise) are zeroed out, while large coefficients (true
224-
# signal features) are kept but reduced by the threshold amount. Only detail
225-
# levels (indices 1..n_levels-1) are thresholded; the coarse approximation
226-
# coefficients (index 0) are left untouched.
227-
#
228-
# sigma: robust noise-level estimate via the median absolute deviation of
229-
# the finest detail level. Dividing by 0.6745 converts MAD to an
230-
# estimate of the Gaussian standard deviation.
231-
# thresh: per-column Donoho & Johnstone (1994) universal threshold,
232-
# sigma * sqrt(2 * log(N)), scaled by the user-supplied `threshold`.
282+
# Robust noise estimate from finest details, then Donoho-Johnstone
283+
# soft-thresholding on detail bands only.
233284
sigma = np.median(np.abs(coeffs_all[-1]), axis=0) / 0.6745
234-
np.maximum(sigma, 1e-10, out=sigma) # floor avoids zero threshold on clean signals
235-
236-
thresh = threshold * sigma * np.sqrt(2 * np.log(N)) # shape (M,)
237-
285+
np.maximum(sigma, 1e-10, out=sigma)
286+
thresh = threshold * sigma * np.sqrt(2 * np.log(N))
238287
coeffs_denoised = [coeffs_all[0]] + [
239288
pywt.threshold(c, thresh[np.newaxis, :], mode='soft')
240289
for c in coeffs_all[1:]
241290
]
242291

243-
# Reconstruct x_hat and differentiate column by column.
244-
# pywt.waverec is 1-D only, so the column loop is unavoidable here;
245-
# the vectorised operations above have already moved all Python-level
246-
# arithmetic outside this loop.
247-
#
248-
# After wavelet denoising we have a smooth, noise-free signal. We
249-
# differentiate it analytically in the Fourier domain: multiplying the
250-
# FFT by i*omega is equivalent to applying the derivative operator exactly,
251-
# with no finite-difference truncation error. This keeps the two concerns
252-
# cleanly separated — wavelets handle denoising, Fourier handles
253-
# differentiation.
254-
x_hat_flat = np.empty_like(x_flat)
255-
dxdt_hat_flat = np.empty_like(x_flat)
256-
257-
# Angular frequency axis for a length-N signal sampled at dt.
258-
# fftfreq returns cycles/sample; multiplying by 2*pi/dt gives rad/s.
259-
k = np.fft.fftfreq(N, d=dt) * 2 * np.pi
292+
Dphi, matrix_coeff_lengths = _wavelet_derivative_synthesis_matrix(
293+
N, float(dt), wavelet, int(level), mode
294+
)
295+
if tuple(coeff_lengths) != tuple(matrix_coeff_lengths):
296+
raise RuntimeError("Cached wavelet derivative matrix coefficient layout does not match wavedec output.")
297+
298+
x_hat_flat = np.empty_like(x_flat)
299+
coeffs_flat = np.empty((sum(coeff_lengths), M), dtype=x_flat.dtype)
300+
offsets = np.cumsum((0,) + tuple(coeff_lengths[:-1]))
260301

261302
for col in range(M):
262303
col_coeffs = [coeffs_denoised[i][:, col] for i in range(n_levels)]
263-
x_hat_col = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
264-
x_hat_flat[:, col] = x_hat_col
265-
X = np.fft.fft(x_hat_col)
266-
dxdt_hat_flat[:, col] = np.real(np.fft.ifft(1j * k * X))
304+
x_hat_flat[:, col] = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
305+
for i, (offset, length) in enumerate(zip(offsets, coeff_lengths)):
306+
coeffs_flat[offset:offset + length, col] = coeffs_denoised[i][:, col]
307+
308+
dxdt_hat_flat = Dphi @ coeffs_flat
267309

268-
# Restore original shape and axis order.
269-
x_hat = np.moveaxis(x_hat_flat.reshape(shape), 0, axis)
310+
x_hat = np.moveaxis(x_hat_flat.reshape(shape), 0, axis)
270311
dxdt_hat = np.moveaxis(dxdt_hat_flat.reshape(shape), 0, axis)
271312

272313
return x_hat, dxdt_hat

0 commit comments

Comments
 (0)