Skip to content

Commit 8e8d46e

Browse files
pavelkomarovclaude
andcommitted
Inline the wavelet derivative operator into waveletdiff
Fold the private _wavelet_derivative_operator into waveletdiff as a standalone block, mirroring how rbfdiff builds its basis and derivative matrices in one loop. No other caller needs it, and the connection-coefficient math now lives in waveletdiff's own docstring. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
1 parent c8003a4 commit 8e8d46e

1 file changed

Lines changed: 36 additions & 63 deletions

File tree

pynumdiff/basis_fit.py

Lines changed: 36 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -136,73 +136,27 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01, axis=0):
136136
return np.moveaxis(x_hat_flattened.reshape(plump), 0, axis), np.moveaxis(dxdt_hat_flattened.reshape(plump), 0, axis)
137137

138138

139-
def _wavelet_derivative_operator(N, dt, wavelet):
140-
"""Build the sparse operators that turn denoised samples into a derivative.
141-
142-
Depends only on the grid (N, dt) and the wavelet, not on the data, and is
143-
built once per waveletdiff call then applied to every column at once.
144-
145-
PyWavelets treats the input samples as the finest-level scaling coefficients,
146-
so the denoised reconstruction x_hat represents the continuous interpolant
147-
x(t) = sum_n a_n phi(t/dt - n), where phi is the wavelet's scaling function.
148-
Sampling x and its analytic derivative on the grid t_m = m*dt gives two
149-
convolutions against phi and phi' evaluated at *integers*:
150-
151-
x_hat[m] = sum_n a_n phi(m - n) -> x_hat = Phi @ a
152-
x'(t_m) = (1/dt) sum_n a_n phi'(m - n) -> x' = Phi_prime @ a
153-
154-
so x' = Phi_prime @ Phi^-1 @ x_hat. This is the exact derivative of the
155-
wavelet interpolant, with no spline or finite-difference approximation.
156-
157-
phi and phi' at the integers are the eigenvectors of the wavelet's refinement
158-
(dilation) relation: differentiating phi(t) = sqrt2 * sum_k h_k phi(2t - k)
159-
and evaluating at integers shows phi sampled at integers is the eigenvalue-1
160-
eigenvector and phi' the eigenvalue-1/2 eigenvector of T[p,q] = sqrt2 * h_{2p-q}.
161-
Normalizations come from reproduction of constants (sum phi(p) = 1) and of
162-
linears (sum p*phi'(p) = -1, so the operator differentiates a ramp exactly).
163-
164-
:return: - **Phi** (csc_matrix) -- circulant samples of phi, to be inverted
165-
- **Phi_prime** (csr_matrix) -- circulant samples of phi'/dt
166-
"""
167-
h = np.array(pywt.Wavelet(wavelet).rec_lo) # reconstruction low-pass = refinement filter h_k
168-
h = h / h.sum() * np.sqrt(2) # enforce sum(h) = sqrt2, i.e. integral of phi is 1
169-
L = len(h) # phi is supported on [0, L-1]; sample those integers
170-
p = np.arange(L)
171-
172-
# Transition matrix T[p,q] = sqrt2 * h_{2p-q}; entries outside the filter are 0.
173-
cols = 2 * p[:, None] - p[None, :]
174-
T = np.where((cols >= 0) & (cols < L), np.sqrt(2) * h[np.clip(cols, 0, L - 1)], 0.0)
175-
176-
evals, evecs = np.linalg.eig(T)
177-
phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))]) # phi(p): eigenvalue 1
178-
dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))]) # phi'(p): eigenvalue 1/2
179-
phi = phi / phi.sum() # sum_p phi(p) = 1
180-
dphi = dphi / np.dot(p, dphi) * (-1.0) # sum_p p*phi'(p) = -1
181-
182-
# Both kernels become circulant matrices under periodic boundaries; a common
183-
# shift of both cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
184-
def circulant(kernel):
185-
rows, cols, vals = [], [], []
186-
m = np.arange(N)
187-
for offset, val in zip(p, kernel):
188-
if abs(val) < 1e-12: continue
189-
rows.extend(m); cols.extend((m - offset) % N); vals.extend([val] * N)
190-
return sparse.csr_matrix((vals, (rows, cols)), shape=(N, N))
191-
192-
return circulant(phi).tocsc(), circulant(dphi / dt)
193-
194-
195139
def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='periodization'):
196140
"""Smooth and differentiate noisy data in a wavelet basis.
197141
198142
Three steps: (1) decompose x with the DWT and soft-threshold the detail
199143
coefficients to denoise (Donoho-Johnstone universal threshold), reconstructing
200144
a smoothed x_hat; (2) extend x_hat antisymmetrically so the periodic derivative
201-
operator stays accurate at the edges; (3) recover the scaling coefficients of
202-
x_hat and apply the analytic derivative of the wavelet basis to get the
203-
derivative. The derivative operator differentiates the basis functions
204-
themselves (see :func:`_wavelet_derivative_operator`) rather than
205-
finite-differencing the signal, so it is exact for signals the basis can represent.
145+
operator stays accurate at the edges; (3) recover the wavelet scaling
146+
coefficients of x_hat and apply the analytic derivative of the wavelet basis.
147+
148+
The derivative differentiates the basis functions themselves rather than
149+
finite-differencing the signal. PyWavelets treats the samples as finest-level
150+
scaling coefficients, so x_hat is the interpolant x(t) = sum_n a_n phi(t/dt - n)
151+
for the scaling function phi. Sampling x and its analytic derivative on the grid
152+
gives two convolutions against phi and phi' evaluated at *integers*,
153+
154+
x_hat = Phi @ a and x' = Phi_prime @ a,
155+
156+
so x' = Phi_prime @ Phi^-1 @ x_hat, exact for signals the basis can represent.
157+
The integer samples phi(p), phi'(p) are the eigenvalue-1 and eigenvalue-1/2
158+
eigenvectors of the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k)
159+
(the "connection coefficients"), normalized to reproduce constants and ramps.
206160
207161
Because the DWT requires uniform spacing, this method only accepts a scalar
208162
time step dt (not a vector of sample times). For non-uniformly sampled data,
@@ -239,6 +193,26 @@ def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='p
239193
x_work = np.ascontiguousarray(np.moveaxis(x, axis, 0)) # differentiation axis to front
240194
shape = x_work.shape # remember it to restore the input's dimensionality
241195
x_flat = x_work.reshape(N, -1) # rest of the dims flattened into columns
196+
Ne = 3 * N - 2 # length after the antisymmetric extension in step 2
197+
198+
# Build the wavelet-basis derivative operator (depends only on the grid and wavelet).
199+
# Sampling the refinement relation phi(t) = sqrt2 sum_k h_k phi(2t - k) at integers makes
200+
# phi(p) the eigenvalue-1 and phi'(p) the eigenvalue-1/2 eigenvector of T[p,q] = sqrt2 h_{2p-q}.
201+
h = np.array(pywt.Wavelet(wavelet).rec_lo); h = h / h.sum() * np.sqrt(2) # refinement filter, integral of phi = 1
202+
L = len(h); p = np.arange(L) # phi is supported on the integers [0, L-1]
203+
shift = 2 * p[:, None] - p[None, :]
204+
T = np.where((shift >= 0) & (shift < L), np.sqrt(2) * h[np.clip(shift, 0, L - 1)], 0.0)
205+
evals, evecs = np.linalg.eig(T)
206+
phi = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))]); phi /= phi.sum() # sum_p phi(p) = 1
207+
dphi = np.real(evecs[:, np.argmin(np.abs(evals - 0.5))]); dphi /= np.dot(p, dphi)*-1 # sum_p p*phi'(p) = -1
208+
# Phi and Phi_prime hold circulant samples of phi and phi'/dt on the extended grid; both
209+
# share a common shift that cancels in Phi_prime @ Phi^-1, so the offset choice is cosmetic.
210+
rows, cols, phi_vals, dphi_vals = [], [], [], []
211+
m = np.arange(Ne)
212+
for offset, phi_p, dphi_p in zip(p, phi, dphi / dt):
213+
rows.extend(m); cols.extend((m - offset) % Ne); phi_vals.extend([phi_p]*Ne); dphi_vals.extend([dphi_p]*Ne)
214+
Phi = sparse.csr_matrix((phi_vals, (rows, cols)), shape=(Ne, Ne)).tocsc() # to invert
215+
Phi_prime = sparse.csr_matrix((dphi_vals, (rows, cols)), shape=(Ne, Ne)) # to apply
242216

243217
if level is None:
244218
level = min(pywt.dwt_max_level(N, wavelet), 5)
@@ -262,9 +236,8 @@ def waveletdiff(x, dt, wavelet='db8', level=None, threshold=1.0, axis=0, mode='p
262236

263237
# 3. Differentiate the basis: recover the scaling coefficients a = Phi^-1 @ x_ext, then
264238
# apply the analytic basis derivative dxdt = Phi_prime @ a, and crop back to the original.
265-
Phi, Phi_prime = _wavelet_derivative_operator(x_ext.shape[0], float(dt), wavelet)
266239
a = sparse.linalg.spsolve(Phi, x_ext)
267-
dxdt_flat = (Phi_prime @ a.reshape(x_ext.shape[0], -1))[N - 1:2 * N - 1]
240+
dxdt_flat = (Phi_prime @ a.reshape(Ne, -1))[N - 1:2 * N - 1]
268241

269242
x_hat = np.moveaxis(x_hat.reshape(shape), 0, axis)
270243
dxdt_hat = np.moveaxis(dxdt_flat.reshape(shape), 0, axis)

0 commit comments

Comments
 (0)