Skip to content

Commit 68f9255

Browse files
committed
Fix waveletdiff: use np.gradient on denoised signal for derivative
1 parent af06d01 commit 68f9255

1 file changed

Lines changed: 10 additions & 26 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -240,40 +240,24 @@ def waveletdiff(x, dt, wavelet='db4', level=None, threshold=1.0, axis=0, mode='p
240240
for c in coeffs_all[1:]
241241
]
242242

243-
# Build derivative reconstruction filters from the wavelet's reconstruction
244-
# filters. Because the DWT reconstructs a signal as a linear combination of
245-
# shifted scaling/wavelet functions, the derivative of the reconstruction is
246-
# the same linear combination of the *derivatives* of those basis functions.
247-
# We obtain derivative filters by finite-differencing the reconstruction
248-
# lowpass filter (rec_lo), then scaling by 1/dt to convert discrete
249-
# differences to continuous-time derivatives.
250-
w = pywt.Wavelet(wavelet)
251-
rec_lo = np.array(w.rec_lo)
252-
# First-order finite difference of the filter gives the derivative filter.
253-
# np.diff shortens by 1; padding with a leading zero keeps the filter length
254-
# and phase consistent with the original so waverec alignment is preserved.
255-
d_rec_lo = np.concatenate(([0.0], np.diff(rec_lo))) / dt
256-
d_rec_hi = np.concatenate(([0.0], np.diff(np.array(w.rec_hi)))) / dt
257-
258-
# Reconstruct x_hat and dxdt_hat column by column.
243+
# Reconstruct x_hat and differentiate column by column.
259244
# pywt.waverec is 1-D only, so the column loop is unavoidable here;
260245
# the vectorised operations above have already moved all Python-level
261246
# arithmetic outside this loop.
247+
#
248+
# After wavelet denoising we have a smooth, noise-free signal. np.gradient
249+
# applies a second-order central finite difference to that clean signal,
250+
# which gives an accurate derivative. This is appropriate here because the
251+
# heavy lifting (noise removal) has already been done by the wavelet
252+
# thresholding step; np.gradient on a smooth signal converges at O(dt^2).
262253
x_hat_flat = np.empty_like(x_flat)
263254
dxdt_hat_flat = np.empty_like(x_flat)
264255

265256
for col in range(M):
266257
col_coeffs = [coeffs_denoised[i][:, col] for i in range(n_levels)]
267-
268-
# Standard reconstruction for the smoothed signal.
269-
x_hat_flat[:, col] = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
270-
271-
# Derivative reconstruction: replace the wavelet's reconstruction
272-
# filters with their finite-difference derivatives and run waverec.
273-
d_wavelet = pywt.Wavelet(
274-
filter_bank=(w.dec_lo, w.dec_hi, d_rec_lo, d_rec_hi)
275-
)
276-
dxdt_hat_flat[:, col] = pywt.waverec(col_coeffs, d_wavelet, mode=mode)[:N]
258+
x_hat_col = pywt.waverec(col_coeffs, wavelet, mode=mode)[:N]
259+
x_hat_flat[:, col] = x_hat_col
260+
dxdt_hat_flat[:, col] = np.gradient(x_hat_col, dt)
277261

278262
# Restore original shape and axis order.
279263
x_hat = np.moveaxis(x_hat_flat.reshape(shape), 0, axis)

0 commit comments

Comments
 (0)