11"""Methods based on fitting basis functions to data"""
2+ from functools import lru_cache
23from warnings import warn
34import numpy as np
45from scipy import sparse
6+ from scipy .interpolate import CubicSpline
57import pywt
68
79from pynumdiff .utils import utility
@@ -136,14 +138,94 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
136138 return np .moveaxis (x_hat_flattened .reshape (plump ), 0 , axis ), np .moveaxis (dxdt_hat_flattened .reshape (plump ), 0 , axis )
137139
138140
def _basis_derivative_samples(t, basis, eps):
    """Sample the first derivative of one synthesis basis vector with a local cubic spline.

    :param np.array t: strictly increasing sample times, one per signal sample.
    :param np.array basis: one synthesis basis vector sampled at :code:`t`.
    :param float eps: magnitude at or below which basis and derivative samples are treated as zero.
    :return: - **rows** (np.array) -- sample indices where the derivative is finite and above eps
             - **vals** (np.array) -- derivative values at those indices
    """
    N = t.size
    nothing = (np.empty(0, dtype=np.intp), np.empty(0, dtype=float))

    # Only the samples where the basis vector is active contribute, which keeps
    # the resulting matrix sparse.
    active = np.flatnonzero(np.abs(basis) > eps)
    if active.size == 0:
        return nothing

    # Pad the active support by one sample on each side (clipped to the signal)
    # so the cubic has context at the edges of the support.
    support = np.zeros(N, dtype=bool)
    support[active] = True
    support[np.maximum(active - 1, 0)] = True
    support[np.minimum(active + 1, N - 1)] = True
    idx = np.flatnonzero(support)

    # Too few points, or a support whose span is more than twice its sample
    # count (e.g. split across both ends of the signal): use every sample.
    # NOTE(review): in that fallback the spline is an ordinary not-a-knot fit over
    # the whole record, so it does not treat the two ends as neighbours -- confirm
    # this is acceptable for wrap-around supports under mode='periodization'.
    if idx.size < 4 or (idx[-1] - idx[0] + 1) > 2 * idx.size:
        idx = np.arange(N)

    # CubicSpline cannot be built from fewer than two points. That can only
    # happen here when the signal itself has a single sample; no slope is
    # defined then, so this column contributes nothing instead of raising.
    if idx.size < 2:
        return nothing

    # not-a-knot with four or more points; otherwise clamp both end slopes to zero.
    bc_type = 'not-a-knot' if idx.size >= 4 else ((1, 0.0), (1, 0.0))
    spline = CubicSpline(t[idx], basis[idx], bc_type=bc_type, extrapolate=False)
    deriv_vals = spline(t[idx], 1)
    keep = np.isfinite(deriv_vals) & (np.abs(deriv_vals) > eps)
    return idx[keep], deriv_vals[keep]


@lru_cache(maxsize=32)
def _wavelet_derivative_synthesis_matrix(N, dt, wavelet, level, mode):
    """Build sparse samples of d/dt of the inverse-DWT synthesis basis.

    For a fixed wavelet/level/mode/length, wavedec/waverec define a linear
    synthesis map

        x(t_n) = sum_k c_k phi_k(t_n).

    This routine samples phi'_k(t_n) once, stores those samples sparsely, and
    lets waveletdiff compute

        x'(t_n) = sum_k c_k phi'_k(t_n)

    without differentiating the reconstructed signal. Each basis vector phi_k
    is obtained by running waverec on a unit coefficient vector, and its
    derivative samples come from a local cubic interpolant of that vector; this
    is bookkeeping on the basis functions, not a finite-difference derivative
    of the data.

    Results are memoised with lru_cache, so every argument must be hashable and
    repeated calls with equal arguments return the *same* matrix object. Treat
    the returned matrix as read-only.

    :param int N: number of samples in the signal.
    :param float dt: uniform time step between samples.
    :param wavelet: wavelet specification forwarded to pywt.wavedec/waverec.
    :param int level: decomposition depth forwarded to pywt.wavedec.
    :param str mode: signal extension mode forwarded to pywt.wavedec/waverec.
    :return: - **Dphi** (scipy.sparse.csr_matrix) -- shape (N, n_coeffs); column k holds the sampled
               derivative of the k-th synthesis basis vector, with coefficients ordered band by band
               as wavedec returns them
             - **coeff_lengths** (tuple) -- number of coefficients in each wavedec band
    """
    # Decompose a zero signal purely to learn the coefficient layout.
    template = pywt.wavedec(np.zeros(N), wavelet, level=level, mode=mode)
    coeff_lengths = tuple(len(c) for c in template)
    coeff_offsets = np.cumsum((0,) + coeff_lengths[:-1])
    n_coeffs = sum(coeff_lengths)
    t = np.arange(N, dtype=float) * dt
    eps = 1e-12

    # One array chunk per nonempty column; concatenated once at the end rather
    # than growing Python lists element by element.
    row_chunks, col_chunks, val_chunks = [], [], []

    for band, (offset, length) in enumerate(zip(coeff_offsets, coeff_lengths)):
        for local_idx in range(length):
            # Unit coefficient vector -> one synthesis basis vector.
            coeffs = [np.zeros_like(c, dtype=float) for c in template]
            coeffs[band][local_idx] = 1.0
            basis = pywt.waverec(coeffs, wavelet, mode=mode)[:N]

            rows, vals = _basis_derivative_samples(t, basis, eps)
            if rows.size == 0:
                continue

            row_chunks.append(rows)
            col_chunks.append(np.full(rows.size, offset + local_idx, dtype=np.intp))
            val_chunks.append(vals)

    if row_chunks:
        rows = np.concatenate(row_chunks)
        cols = np.concatenate(col_chunks)
        vals = np.concatenate(val_chunks)
    else:
        rows = np.empty(0, dtype=np.intp)
        cols = np.empty(0, dtype=np.intp)
        vals = np.empty(0, dtype=float)

    return sparse.csr_matrix((vals, (rows, cols)), shape=(N, n_coeffs)), coeff_lengths
207+
208+
209+ def _flatten_wavelet_coeffs (coeffs ):
210+ """Stack a wavedec coefficient list into a 2-D coefficient matrix."""
211+ return np .vstack ([c for band in coeffs for c in band ])
212+
213+
139214def waveletdiff (x , dt , wavelet = 'db4' , level = None , threshold = 1.0 , axis = 0 , mode = 'periodization' ):
140- """Smooth and differentiate noisy data via discrete wavelet denoising .
215+ """Smooth and differentiate noisy data with a wavelet-basis derivative sum .
141216
142- Decomposes x into wavelet detail and approximation coefficients, soft-thresholds
143- the detail coefficients to remove noise using the Donoho & Johnstone (1994)
144- universal threshold estimator, reconstructs a smoothed signal, then
145- differentiates analytically by applying derivative reconstruction filters to
146- the denoised wavelet coefficients.
217+ Decomposes x into wavelet approximation/detail coefficients, soft-thresholds
218+ the detail coefficients to denoise, reconstructs a smoothed signal, and then
219+ estimates the derivative directly from the denoised wavelet coefficients:
220+
221+ x(t_n) = sum_k c_k phi_k(t_n)
222+ x'(t_n) = sum_k c_k phi'_k(t_n)
223+
224+ The first sum is the ordinary inverse wavelet transform. The second sum is
225+ evaluated by precomputing sparse samples of the derivative of each synthesis
226+ basis function and multiplying that sparse matrix by the denoised
227+ coefficients. This avoids the previous reconstruct-then-FFT derivative path
228+ and does not call finite differences or np.gradient on the signal.
147229
148230 Because the DWT requires uniform spacing, this method only accepts a scalar
149231 time step dt (not a vector of sample times). For non-uniformly sampled data,
@@ -152,24 +234,13 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
152234 :param np.array x: data to differentiate. May be multidimensional; see :code:`axis`.
153235 :param float dt: uniform time step between samples.
154236 :param str wavelet: PyWavelets wavelet name, e.g. 'db4', 'sym4', 'coif2'.
155- 'db4' is a solid general-purpose default. Biorthogonal wavelets such as
156- 'bior2.2' or 'bior4.4' are symmetric and designed for smooth reconstruction
157- but may need a lower threshold value.
158237 :param int level: decomposition depth. None (default) resolves to
159- min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short
160- signals. Increase for heavily oversampled data.
238+ min(pywt.dwt_max_level(N, wavelet), 5) to avoid over-decomposing short signals.
161239 :param float threshold: soft-thresholding scale factor in [0, inf).
162- Multiplies the universal threshold sigma * sqrt(2 * log(N)).
163- threshold=1.0 is the classical Donoho & Johnstone universal threshold
164- and is the recommended starting point. Values < 1.0 give less smoothing;
165- values > 1.0 give more aggressive smoothing. This parameter maps onto
166- tvgamma in the pynumdiff.optimize framework.
167240 :param int axis: axis along which to differentiate (default 0).
168241 :param str mode: PyWavelets signal extension mode passed to wavedec/waverec.
169- 'periodization' (default) keeps coefficient arrays exactly length N and
170- is the most numerically stable choice for differentiation. 'reflect' is
171- a good alternative for clearly non-periodic signals.
172- See pywt.Modes.modes for all options.
242+ 'periodization' keeps coefficient arrays compact; 'reflect' is often a
243+ better choice for clearly non-periodic signals.
173244 :return: - **x_hat** (np.array) -- estimated (smoothed) x
174245 - **dxdt_hat** (np.array) -- estimated derivative of x
175246 """
@@ -180,19 +251,11 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
180251 )
181252
182253 N = x .shape [axis ]
183-
184- # Bring axis of differentiation to front so each column of x_flat is one
185- # signal to differentiate. moveaxis returns a view with updated strides,
186- # so ascontiguousarray ensures the subsequent reshape is zero-copy.
187254 x_work = np .ascontiguousarray (np .moveaxis (x , axis , 0 ))
188255 shape = x_work .shape
189- x_flat = x_work .reshape (N , - 1 ) # (N, M)
256+ x_flat = x_work .reshape (N , - 1 )
190257 M = x_flat .shape [1 ]
191258
192- # Conservative level cap: pywt's default uses the maximum possible level,
193- # which can over-decompose short signals and wash out meaningful detail.
194- # Capping at 5 keeps at least 2^5 = 32 samples in the coarsest subband,
195- # which is enough to represent a smooth approximation without artefacts.
196259 if level is None :
197260 max_level = pywt .dwt_max_level (N , wavelet )
198261 level = min (max_level , 5 )
@@ -216,57 +279,35 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
216279 ):
217280 coeffs_all [i ][:, col ] = c
218281
219- # Vectorised noise estimation and soft-thresholding over all columns at once.
220- #
221- # Soft-thresholding achieves smoothing by shrinking wavelet detail
222- # coefficients toward zero: coefficients whose magnitude is below the
223- # threshold (mostly noise) are zeroed out, while large coefficients (true
224- # signal features) are kept but reduced by the threshold amount. Only detail
225- # levels (indices 1..n_levels-1) are thresholded; the coarse approximation
226- # coefficients (index 0) are left untouched.
227- #
228- # sigma: robust noise-level estimate via the median absolute deviation of
229- # the finest detail level. Dividing by 0.6745 converts MAD to an
230- # estimate of the Gaussian standard deviation.
231- # thresh: per-column Donoho & Johnstone (1994) universal threshold,
232- # sigma * sqrt(2 * log(N)), scaled by the user-supplied `threshold`.
282+ # Robust noise estimate from finest details, then Donoho-Johnstone
283+ # soft-thresholding on detail bands only.
233284 sigma = np .median (np .abs (coeffs_all [- 1 ]), axis = 0 ) / 0.6745
234- np .maximum (sigma , 1e-10 , out = sigma ) # floor avoids zero threshold on clean signals
235-
236- thresh = threshold * sigma * np .sqrt (2 * np .log (N )) # shape (M,)
237-
285+ np .maximum (sigma , 1e-10 , out = sigma )
286+ thresh = threshold * sigma * np .sqrt (2 * np .log (N ))
238287 coeffs_denoised = [coeffs_all [0 ]] + [
239288 pywt .threshold (c , thresh [np .newaxis , :], mode = 'soft' )
240289 for c in coeffs_all [1 :]
241290 ]
242291
243- # Reconstruct x_hat and differentiate column by column.
244- # pywt.waverec is 1-D only, so the column loop is unavoidable here;
245- # the vectorised operations above have already moved all Python-level
246- # arithmetic outside this loop.
247- #
248- # After wavelet denoising we have a smooth, noise-free signal. We
249- # differentiate it analytically in the Fourier domain: multiplying the
250- # FFT by i*omega is equivalent to applying the derivative operator exactly,
251- # with no finite-difference truncation error. This keeps the two concerns
252- # cleanly separated — wavelets handle denoising, Fourier handles
253- # differentiation.
254- x_hat_flat = np .empty_like (x_flat )
255- dxdt_hat_flat = np .empty_like (x_flat )
256-
257- # Angular frequency axis for a length-N signal sampled at dt.
258- # fftfreq returns cycles/sample; multiplying by 2*pi/dt gives rad/s.
259- k = np .fft .fftfreq (N , d = dt ) * 2 * np .pi
292+ Dphi , matrix_coeff_lengths = _wavelet_derivative_synthesis_matrix (
293+ N , float (dt ), wavelet , int (level ), mode
294+ )
295+ if tuple (coeff_lengths ) != tuple (matrix_coeff_lengths ):
296+ raise RuntimeError ("Cached wavelet derivative matrix coefficient layout does not match wavedec output." )
297+
298+ x_hat_flat = np .empty_like (x_flat )
299+ coeffs_flat = np .empty ((sum (coeff_lengths ), M ), dtype = x_flat .dtype )
300+ offsets = np .cumsum ((0 ,) + tuple (coeff_lengths [:- 1 ]))
260301
261302 for col in range (M ):
262303 col_coeffs = [coeffs_denoised [i ][:, col ] for i in range (n_levels )]
263- x_hat_col = pywt .waverec (col_coeffs , wavelet , mode = mode )[:N ]
264- x_hat_flat [:, col ] = x_hat_col
265- X = np .fft .fft (x_hat_col )
266- dxdt_hat_flat [:, col ] = np .real (np .fft .ifft (1j * k * X ))
304+ x_hat_flat [:, col ] = pywt .waverec (col_coeffs , wavelet , mode = mode )[:N ]
305+ for i , (offset , length ) in enumerate (zip (offsets , coeff_lengths )):
306+ coeffs_flat [offset :offset + length , col ] = coeffs_denoised [i ][:, col ]
307+
308+ dxdt_hat_flat = Dphi @ coeffs_flat
267309
268- # Restore original shape and axis order.
269- x_hat = np .moveaxis (x_hat_flat .reshape (shape ), 0 , axis )
310+ x_hat = np .moveaxis (x_hat_flat .reshape (shape ), 0 , axis )
270311 dxdt_hat = np .moveaxis (dxdt_hat_flat .reshape (shape ), 0 , axis )
271312
272313 return x_hat , dxdt_hat
0 commit comments