Skip to content

Commit baf9946

Browse files
pavelkomarovclaude
andcommitted
Compute wavelet derivative analytically via connection coefficients
Replace the cubic-spline-of-sampled-basis derivative with the exact analytic derivative of the wavelet basis. The denoised reconstruction is the wavelet interpolant x(t) = sum_n a_n phi(t/dt - n), so its derivative is x' = Phi' Phi^-1 x_hat, where Phi and Phi' are circulant samples of the scaling function phi and its derivative phi' at integers. Those samples are the eigenvalue-1 and eigenvalue-1/2 eigenvectors of the wavelet refinement relation (connection coefficients), normalized to reproduce constants and differentiate ramps exactly. No splines or finite differences. - Antisymmetrically extend x_hat before applying the periodic operator so edge derivatives stay accurate instead of spiking at the wrap. - Reject Haar/db1, whose step scaling function has no derivative. - Default to db8 (smoother -> better derivatives on noisy data) and collapse the two old helpers into one operator builder; vectorize the DWT over columns. - Retighten the waveletdiff error bounds for the new, exact derivative. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent dbf2f3f commit baf9946

2 files changed

Lines changed: 117 additions & 160 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 108 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from warnings import warn
44
import numpy as np
55
from scipy import sparse
6-
from scipy.interpolate import CubicSpline
76
import pywt
87

98
from pynumdiff.utils import utility
@@ -139,175 +138,133 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
139138

140139

141140
@lru_cache(maxsize=32)
142-
def _wavelet_derivative_synthesis_matrix(N, dt, wavelet, level, mode):
143-
"""Build sparse samples of d/dt of the inverse-DWT synthesis basis.
144-
145-
For a fixed wavelet/level/mode/length, wavedec/waverec define a linear
146-
synthesis map
147-
148-
x(t_n) = sum_k c_k phi_k(t_n).
149-
150-
This routine samples phi'_k(t_n) once, stores those samples sparsely, and
151-
lets waveletdiff compute
152-
153-
x'(t_n) = sum_k c_k phi'_k(t_n)
154-
155-
without differentiating the reconstructed signal. The derivative samples
156-
are obtained from a local cubic interpolant of each compactly supported
157-
synthesis basis vector; this is bookkeeping on the basis functions, not a
158-
finite-difference derivative of the data.
141+
def _wavelet_derivative_operator(N, dt, wavelet):
142+
"""Build the sparse operators that turn denoised samples into a derivative.
143+
144+
PyWavelets treats the input samples as the finest-level scaling coefficients,
145+
so the denoised reconstruction x_hat represents the continuous interpolant
146+
x(t) = sum_n a_n phi(t/dt - n), where phi is the wavelet's scaling function.
147+
Sampling x and its analytic derivative on the grid t_m = m*dt gives two
148+
convolutions against phi and phi' evaluated at *integers*:
149+
150+
x_hat[m] = sum_n a_n phi(m - n) -> x_hat = Phi @ a
151+
x'(t_m) = (1/dt) sum_n a_n phi'(m - n) -> x' = Phi_prime @ a
152+
153+
so x' = Phi_prime @ Phi^-1 @ x_hat. This is the exact derivative of the
154+
wavelet interpolant, with no spline or finite-difference approximation.
155+
156+
phi and phi' at the integers are the eigenvectors of the wavelet's refinement
157+
(dilation) relation: differentiating phi(t) = sqrt2 * sum_k h_k phi(2t - k)
158+
and evaluating at integers shows phi sampled at integers is the eigenvalue-1
159+
eigenvector and phi' the eigenvalue-1/2 eigenvector of T[p,q] = sqrt2 * h_{2p-q}.
160+
Normalizations come from reproduction of constants (sum phi(p) = 1) and of
161+
linears (sum p*phi'(p) = -1, so the operator differentiates a ramp exactly).
162+
163+
:return: - **Phi** (csc_matrix) -- circulant samples of phi, to be inverted
164+
- **Phi_prime** (csr_matrix) -- circulant samples of phi'/dt
159165
"""
160-
zero = np.zeros(N)
161-
template = pywt.wavedec(zero, wavelet, level=level, mode=mode)
162-
coeff_lengths = tuple(len(c) for c in template)
163-
coeff_offsets = np.cumsum((0,) + coeff_lengths[:-1])
164-
n_coeffs = sum(coeff_lengths)
165-
t = np.arange(N, dtype=float) * dt
166-
167-
rows, cols, vals = [], [], []
168-
eps = 1e-12
169-
170-
for band, (offset, length) in enumerate(zip(coeff_offsets, coeff_lengths)):
171-
for local_idx in range(length):
172-
coeffs = [np.zeros_like(c, dtype=float) for c in template]
173-
coeffs[band][local_idx] = 1.0
174-
basis = pywt.waverec(coeffs, wavelet, mode=mode)[:N]
175-
176-
# Basis functions are compactly supported, but boundary extension can
177-
# split support across the two ends. Differentiating only the active
178-
# samples keeps the matrix sparse and avoids global sinusoidal bases.
179-
active = np.flatnonzero(np.abs(basis) > eps)
180-
if active.size == 0:
181-
continue
182-
183-
# Include one-sample padding around active support so the cubic has
184-
# enough context near the edges of the support. If support wraps or
185-
# covers most of the signal, fall back to all samples.
186-
support = np.zeros(N, dtype=bool)
187-
support[active] = True
188-
support[np.maximum(active - 1, 0)] = True
189-
support[np.minimum(active + 1, N - 1)] = True
190-
idx = np.flatnonzero(support)
191-
if idx.size < 4 or (idx[-1] - idx[0] + 1) > 2 * idx.size:
192-
idx = np.arange(N)
193-
194-
# CubicSpline requires strictly increasing x and at least two points.
195-
# With >=4 points the not-a-knot default is well-defined; with fewer,
196-
# fall back to clamped end slopes of zero.
197-
bc_type = 'not-a-knot' if idx.size >= 4 else ((1, 0.0), (1, 0.0))
198-
spline = CubicSpline(t[idx], basis[idx], bc_type=bc_type, extrapolate=False)
199-
deriv_vals = spline(t[idx], 1)
200-
keep = np.isfinite(deriv_vals) & (np.abs(deriv_vals) > eps)
201-
202-
rows.extend(idx[keep])
203-
cols.extend(np.full(np.count_nonzero(keep), offset + local_idx))
204-
vals.extend(deriv_vals[keep])
205-
206-
return sparse.csr_matrix((vals, (rows, cols)), shape=(N, n_coeffs)), coeff_lengths
207-
208-
209-
def _flatten_wavelet_coeffs(coeffs):
210-
"""Stack a wavedec coefficient list into a 2-D coefficient matrix."""
211-
return np.vstack([c for band in coeffs for c in band])
212-
213-
214-
def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='periodization'):
215-
"""Smooth and differentiate noisy data with a wavelet-basis derivative sum.
216-
217-
Decomposes x into wavelet approximation/detail coefficients, soft-thresholds
218-
the detail coefficients to denoise, reconstructs a smoothed signal, and then
219-
estimates the derivative directly from the denoised wavelet coefficients:
220-
221-
x(t_n) = sum_k c_k phi_k(t_n)
222-
x'(t_n) = sum_k c_k phi'_k(t_n)
223-
224-
The first sum is the ordinary inverse wavelet transform. The second sum is
225-
evaluated by precomputing sparse samples of the derivative of each synthesis
226-
basis function and multiplying that sparse matrix by the denoised
227-
coefficients. This avoids the previous reconstruct-then-FFT derivative path
228-
and does not call finite differences or np.gradient on the signal.
166+
h = np.array(pywt.Wavelet(wavelet).rec_lo) # reconstruction low-pass = refinement filter h_k
167+
h = h / h.sum() * np.sqrt(2) # enforce sum(h) = sqrt2, i.e. integral of phi is 1
168+
L = len(h) # phi is supported on [0, L-1]; sample those integers
169+
p = np.arange(L)
170+
171+
# Transition matrix T[p,q] = sqrt2 * h_{2p-q}; entries outside the filter are 0.
172+
cols = 2 * p[:, None] - p[None, :]
173+
T = np.where((cols >= 0) & (cols < L), np.sqrt(2) * h[np.clip(cols, 0, L - 1)], 0.0)
174+
175+
evals, evecs = np.linalg.eig(T)
176+
phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))]) # phi(p): eigenvalue 1
177+
dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))]) # phi'(p): eigenvalue 1/2
178+
phi = phi / phi.sum() # sum_p phi(p) = 1
179+
dphi = dphi / np.dot(p, dphi) * (-1.0) # sum_p p*phi'(p) = -1
180+
181+
# Both kernels become circulant matrices under periodic boundaries; a common
182+
# shift of both cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
183+
def circulant(kernel):
184+
rows, cols, vals = [], [], []
185+
m = np.arange(N)
186+
for offset, val in zip(p, kernel):
187+
if abs(val) < 1e-12: continue
188+
rows.extend(m); cols.extend((m - offset) % N); vals.extend([val] * N)
189+
return sparse.csr_matrix((vals, (rows, cols)), shape=(N, N))
190+
191+
return circulant(phi).tocsc(), circulant(dphi / dt)
192+
193+
194+
def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='periodization'):
195+
"""Smooth and differentiate noisy data in a wavelet basis.
196+
197+
Three steps: (1) decompose x with the DWT and soft-threshold the detail
198+
coefficients to denoise (Donoho-Johnstone universal threshold), reconstructing
199+
a smoothed x_hat; (2) extend x_hat antisymmetrically so the periodic derivative
200+
operator stays accurate at the edges; (3) recover the scaling coefficients of
201+
x_hat and apply the analytic derivative of the wavelet basis to get the
202+
derivative. The derivative operator differentiates the basis functions
203+
themselves (see :func:`_wavelet_derivative_operator`) rather than
204+
finite-differencing the signal, so it is exact for signals the basis can represent.
229205
230206
Because the DWT requires uniform spacing, this method only accepts a scalar
231207
time step dt (not a vector of sample times). For non-uniformly sampled data,
232208
use :func:`rbfdiff` or :func:`splinediff` instead.
233209
234210
:param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
235211
:param float dt: uniform time step between samples.
236-
:param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
212+
:param str wavelet: PyWavelets wavelet name. Must have a differentiable scaling
213+
function, so smoother wavelets give better derivatives: 'db8' (default) and
214+
'sym8' are best for noisy data; 'db4', 'sym4', and 'coif2' also work well.
237215
:param int level: decomposition depth. None (default) resolves to
238216
min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
239217
:param float threshold: soft-thresholding scale factor in [0, inf).
240218
:param int axis: axis along which to differentiate (default 0).
241-
:param str mode: PyWavelets signal extension mode passed to wavedec/waverec.
242-
'periodization' keeps coefficient arrays compact; 'reflect' is often a
243-
better choice for clearly non-periodic signals.
219+
:param str mode: PyWavelets signal extension mode for the denoising transform.
220+
'periodization' keeps coefficient arrays compact. The derivative operator is
221+
periodic, so x_hat is antisymmetrically extended before it is applied (see below).
244222
:return: - **x_hat** (np.array) -- estimated (smoothed) x
245223
- **dxdt_hat** (np.array) -- estimated derivative of x
246224
"""
247225
if not np.isscalar(dt):
248-
raise ValueError(
249-
"`dt` must be a scalar. The DWT requires uniformly sampled data. "
250-
"For variable step sizes, use rbfdiff or splinediff instead."
251-
)
226+
raise ValueError("`dt` must be a scalar. The DWT requires uniformly sampled data. "
227+
"For variable step sizes, use rbfdiff or splinediff instead.")
228+
229+
# The Haar scaling function is a step, so it has no pointwise derivative and the
230+
# connection-coefficient operator below is undefined for it. Haar/db1 is the only
231+
# orthonormal wavelet with a 2-tap filter, so dec_len identifies it.
232+
if pywt.Wavelet(wavelet).dec_len == 2:
233+
raise ValueError("The Haar/db1 wavelet has a discontinuous (piecewise-constant) scaling "
234+
"function with no derivative, so it cannot be used to differentiate. Pick a smoother "
235+
"wavelet such as 'db4', 'sym4', or 'coif2'.")
252236

253237
N = x.shape[axis]
254-
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0))
255-
shape = x_work.shape
256-
x_flat = x_work.reshape(N, -1)
257-
M = x_flat.shape[1]
238+
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0)) # differentiation axis to front
239+
shape = x_work.shape # remember it to restore the input's dimensionality
240+
x_flat = x_work.reshape(N, -1) # rest of the dims flattened into columns
258241

259242
if level is None:
260-
max_level = pywt.dwt_max_level(N, wavelet)
261-
level = min(max_level, 5)
262-
263-
# Decompose all columns; probe column 0 first to learn coefficient lengths
264-
# and pre-allocate, reusing that result so we only pay N+1 wavedec calls.
265-
_probe = pywt.wavedec(x_flat[:, 0], wavelet, level=level, mode=mode)
266-
coeff_lengths = [len(c) for c in _probe]
267-
n_levels = len(_probe)
268-
269-
coeffs_all = [
270-
np.empty((coeff_lengths[i], M), dtype=x_flat.dtype)
271-
for i in range(n_levels)
272-
]
273-
for i, c in enumerate(_probe):
274-
coeffs_all[i][:, 0] = c
275-
276-
for col in range(1, M):
277-
for i, c in enumerate(
278-
pywt.wavedec(x_flat[:, col], wavelet, level=level, mode=mode)
279-
):
280-
coeffs_all[i][:, col] = c
281-
282-
# Robust noise estimate from finest details, then Donoho-Johnstone
283-
# soft-thresholding on detail bands only.
284-
sigma = np.median(np.abs(coeffs_all[-1]), axis=0) / 0.6745
285-
np.maximum(sigma, 1e-10, out=sigma)
286-
thresh = threshold * sigma * np.sqrt(2 * np.log(N))
287-
coeffs_denoised = [coeffs_all[0]] + [
288-
pywt.threshold(c, thresh[np.newaxis, :], mode='soft')
289-
for c in coeffs_all[1:]
290-
]
291-
292-
Dphi, matrix_coeff_lengths = _wavelet_derivative_synthesis_matrix(
293-
N, float(dt), wavelet, int(level), mode
294-
)
295-
if tuple(coeff_lengths) != tuple(matrix_coeff_lengths):
296-
raise RuntimeError("Cached wavelet derivative matrix coefficient layout does not match wavedec output.")
297-
298-
x_hat_flat = np.empty_like(x_flat)
299-
coeffs_flat = np.empty((sum(coeff_lengths), M), dtype=x_flat.dtype)
300-
offsets = np.cumsum((0,) + tuple(coeff_lengths[:-1]))
301-
302-
for col in range(M):
303-
col_coeffs = [coeffs_denoised[i][:, col] for i in range(n_levels)]
304-
x_hat_flat[:, col] = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
305-
for i, (offset, length) in enumerate(zip(offsets, coeff_lengths)):
306-
coeffs_flat[offset:offset + length, col] = coeffs_denoised[i][:, col]
307-
308-
dxdt_hat_flat = Dphi @ coeffs_flat
309-
310-
x_hat = np.moveaxis(x_hat_flat.reshape(shape), 0, axis)
311-
dxdt_hat = np.moveaxis(dxdt_hat_flat.reshape(shape), 0, axis)
243+
level = min(pywt.dwt_max_level(N, wavelet), 5)
312244

245+
# 1. Denoise: DWT all columns at once, then soft-threshold the detail bands. The
246+
# noise level is estimated robustly per column from the finest details (coeffs[-1]).
247+
coeffs = pywt.wavedec(x_flat, wavelet, level=level, mode=mode, axis=0)
248+
sigma = np.maximum(np.median(np.abs(coeffs[-1]), axis=0) / 0.6745, 1e-10)
249+
thresh = threshold * sigma * np.sqrt(2 * np.log(N))
250+
coeffs = [coeffs[0]] + [pywt.threshold(c, thresh[np.newaxis, :], mode='soft') for c in coeffs[1:]]
251+
x_hat = pywt.waverec(coeffs, wavelet, mode=mode, axis=0)[:N]
252+
253+
# 2. The derivative operator is periodic, but x_hat usually isn't. Extend it
254+
# antisymmetrically (reflect through each endpoint: x[-1-k] -> 2*x[0]-x[1+k]) so the
255+
# periodic wrap is continuous in both value and slope, which keeps the derivative
256+
# accurate at the edges instead of spiking there. This is the odd-symmetry analog of
257+
# spectraldiff's even extension; a ramp extends to a ramp, so slopes survive exactly.
258+
left = 2 * x_hat[0] - x_hat[1:][::-1]
259+
right = 2 * x_hat[-1] - x_hat[:-1][::-1]
260+
x_ext = np.concatenate([left, x_hat, right], axis=0) # length 3N-2, original at [N-1:2N-1]
261+
262+
# 3. Differentiate the basis: recover the scaling coefficients a = Phi^-1 @ x_ext, then
263+
# apply the analytic basis derivative dxdt = Phi_prime @ a, and crop back to the original.
264+
Phi, Phi_prime = _wavelet_derivative_operator(x_ext.shape[0], float(dt), wavelet)
265+
a = sparse.linalg.spsolve(Phi, x_ext)
266+
dxdt_flat = (Phi_prime @ a.reshape(x_ext.shape[0], -1))[N - 1:2 * N - 1]
267+
268+
x_hat = np.moveaxis(x_hat.reshape(shape), 0, axis)
269+
dxdt_hat = np.moveaxis(dxdt_flat.reshape(shape), 0, axis)
313270
return x_hat, dxdt_hat

pynumdiff/tests/test_diff_methods.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
5151
(spline_irreg_step, {'degree':5, 's':2}),
5252
(spectraldiff, {'high_freq_cutoff':0.2}), (spectraldiff, [0.2]),
5353
(rbfdiff, {'sigma':0.5, 'lmbd':0.001}),
54-
(waveletdiff, {'wavelet':'db4', 'threshold':1.0}),
54+
(waveletdiff, {'wavelet':'db8', 'threshold':1.0}),
5555
(constant_velocity, {'r':1e-2, 'q':1e3}), (constant_velocity, [1e-2, 1e3]),
5656
(constant_acceleration, {'r':1e-3, 'q':1e4}), (constant_acceleration, [1e-3, 1e4]),
5757
(constant_jerk, {'r':1e-4, 'q':1e5}), (constant_jerk, [1e-4, 1e5]),
@@ -174,12 +174,12 @@ def polydiff_irreg_step(*args, **kwargs): return polydiff(*args, **kwargs)
174174
[(-2, -2), (0, 0), (0, -1), (0, 0)],
175175
[(0, 0), (2, 2), (0, 0), (2, 2)],
176176
[(1, 1), (3, 3), (1, 1), (3, 3)]],
177-
waveletdiff: [[(-14, -15), (-14, -14), (-1, -1), (0, 0)],
178-
[(-9, -9), (-8, -8), (0, 0), (1, 1)],
179-
[(-9, -9), (0, 0), (0, 0), (1, 1)],
180-
[(-1, -1), (0, 0), (0, 0), (1, 1)],
181-
[(1, 0), (2, 2), (1, 1), (2, 2)],
182-
[(0, 0), (3, 3), (1, 0), (3, 3)]],
177+
waveletdiff: [[(-15, -15), (-13, -14), (0, -1), (1, 0)],
178+
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
179+
[(-2, -2), (-1, -1), (0, 0), (1, 1)],
180+
[(-3, -3), (-1, -1), (0, 0), (1, 1)],
181+
[(0, -1), (2, 2), (0, 0), (2, 2)],
182+
[(0, -1), (3, 3), (0, 0), (3, 3)]],
183183
velocity: [[(-25, -25), (-18, -19), (0, -1), (1, 0)],
184184
[(-12, -12), (-11, -12), (-1, -1), (-1, -2)],
185185
[(0, -1), (1, 0), (0, -1), (1, 0)],
@@ -334,7 +334,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
334334
(finitediff, {}),
335335
(polydiff, {'degree': 2, 'window_size': 5}),
336336
(savgoldiff, {'degree': 3, 'window_size': 11, 'smoothing_win': 3}),
337-
(waveletdiff, {'wavelet': 'db4', 'threshold': 1.0}),
337+
(waveletdiff, {'wavelet': 'db8', 'threshold': 1.0}),
338338
(rtsdiff, {'order':2, 'log_qr_ratio':7, 'forwardbackward':True}),
339339
(spectraldiff, {'high_freq_cutoff': 0.25, 'pad_to_zero_dxdt': False}),
340340
(rbfdiff, {'sigma': 0.5, 'lmbd': 1e-6}),
@@ -351,7 +351,7 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
351351
kerneldiff: [(2, 1), (3, 2)],
352352
butterdiff: [(0, -1), (1, -1)],
353353
finitediff: [(0, -1), (1, -1)],
354-
waveletdiff: [(1, 0), (2, 1)],
354+
waveletdiff: [(1, 0), (2, 2)],
355355
polydiff: [(1, -1), (1, 0)],
356356
savgoldiff: [(0, -1), (1, 1)],
357357
rtsdiff: [(1, -1), (1, 0)],

0 commit comments

Comments
 (0)