@@ -34,7 +34,6 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
3434 raise ValueError ("`high_freq_cutoff` must be given." )
3535
3636 L = x .shape [axis ]
37- x = np .asarray (x )
3837
3938 # Make derivative go to zero at the ends (optional)
4039 if pad_to_zero_dxdt :
@@ -44,11 +43,11 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
4443 x = np .concatenate ((pre , x , post ), axis = axis ) # extend the edges
4544 kernel = utility .mean_kernel (padding // 2 )
4645 x_smoothed = utility .convolutional_smoother (x , kernel , axis = axis ) # smooth the padded edges in
47- center = (slice (None ),)* axis + (slice (padding , L + padding ),) + (slice (None ),)* (x .ndim - axis - 1 )
48- x_smoothed [center ] = x [center ] # restore original signal in the middle
46+ m = (slice (None ),)* axis + (slice (padding , L + padding ),) + (slice (None ),)* (x .ndim - axis - 1 ) # middle
47+ x_smoothed [m ] = x [m ] # restore original signal in the middle
4948 x = x_smoothed
5049 else :
51- padding = 0
50+ m = ( slice ( None ),) * axis + ( slice ( 0 , L ),) + ( slice ( None ),) * ( x . ndim - axis - 1 ) # indices where signal lives
5251
5352 # Do even extension (optional)
5453 if even_extension is True :
@@ -63,21 +62,21 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
6362
6463 # Filter to zero out higher wavenumbers
6564 discrete_cutoff = int (high_freq_cutoff * N / 2 ) # Nyquist is at N/2 location, and we're cutting off as a fraction of that
66- filt = np .ones (k .shape ) # start with all frequencies passing
67- filt [discrete_cutoff :- discrete_cutoff ] = 0 # zero out high-frequency components
65+ filt = np .ones (k .shape ) # start with all frequencies passing
66+ filt [discrete_cutoff :- discrete_cutoff ] = 0 # zero out high-frequency components
6867
6968 # Smoothed signal
7069 X = np .fft .fft (x , axis = axis )
7170 x_hat = np .real (np .fft .ifft (filt [s ] * X , axis = axis ))
7271
7372 # Derivative = 90 deg phase shift
7473 omega = 2 * np .pi / (dt * N ) # factor of 2pi/T turns wavenumbers into frequencies in radians/s
75- dxdt = np .real (np .fft .ifft (1j * k [s ] * omega * filt [s ] * X , axis = axis ))
74+ dxdt_hat = np .real (np .fft .ifft (1j * k [s ] * omega * filt [s ] * X , axis = axis ))
7675
77- return ( x_hat [center ], dxdt [ center ]) if pad_to_zero_dxdt else ( x_hat , dxdt )
76+ return x_hat [m ], dxdt_hat [ m ]
7877
7978
80- def rbfdiff (x , dt_or_t , sigma = 1 , lmbd = 0.01 ):
79+ def rbfdiff (x , dt_or_t , sigma = 1 , lmbd = 0.01 , axis = 0 ):
8180 """Find smoothed function and derivative estimates by fitting noisy data with radial-basis-functions. Naively,
8281 fill a matrix with basis function samples and solve a linear inverse problem against the data, but truncate tiny
8382 values to make columns sparse. Each basis function "hill" is topped with a "tower" of height :code:`lmbd` to reach
@@ -88,17 +87,24 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
8887 :math:`\\ Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
8988 :param float sigma: controls width of radial basis functions
9089 :param float lmbd: controls smoothness
90+ :param int axis: data dimension along which differentiation is performed
9191
9292 :return: - **x_hat** (np.array) -- estimated (smoothed) x
9393 - **dxdt_hat** (np.array) -- estimated derivative of x
9494 """
95+ N = x .shape [axis ]
96+ x = np .moveaxis (x , axis , 0 ) # bring axis of differentiation to front so each N repeats comprise vector
97+ plump = x .shape
98+ x_flattened = x .reshape (N , - 1 ) # (N, M) matrix where each column is a vector along the original axis
99+
95100 if np .isscalar (dt_or_t ):
96- t = np .arange (len ( x ) )* dt_or_t
101+ t = np .arange (N )* dt_or_t
97102 else : # support variable step size for this function
98- if len ( x ) != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
103+ if N != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
99104 t = dt_or_t
100105
101- # The below does the approximate equivalent of this code, but sparsely in O(N sigma^2), since the rbf falls off rapidly
106+ # For each vector along the axis of differentiation, the below does the approximate equivalent of this code,
107+ # but sparsely in O(N sigma^2), since the rbf falls off rapidly
102108 # t_i, t_j = np.meshgrid(t,t)
103109 # r = t_j - t_i # radius
104110 # rbf = np.exp(-(r**2) / (2 * sigma**2)) # radial basis function kernel, O(N^2) entries
@@ -118,9 +124,12 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
118124 dv = - radius / sigma ** 2 * v # take derivative of radial basis function, because d/dt coef*f(t) = coef*df/dt
119125 rows .append (n ); cols .append (j ); vals .append (v ); dvals .append (dv )
120126
121- rbf = sparse .csr_matrix ((vals , (rows , cols )), shape = (len (t ), len (t ))) # Build sparse kernels, O(N sigma) entries
122- drbfdt = sparse .csr_matrix ((dvals , (rows , cols )), shape = (len (t ), len (t )))
123- rbf_regularized = rbf + lmbd * sparse .eye (len (t ), format = "csr" ) # identity matrix gives a little extra height at the centers
124- alpha = sparse .linalg .spsolve (rbf_regularized , x ) # solve sparse system targeting the noisy data, O(N sigma^2)
127+ rbf = sparse .csr_matrix ((vals , (rows , cols )), shape = (N , N )) # Build sparse kernels, O(N sigma) entries
128+ drbfdt = sparse .csr_matrix ((dvals , (rows , cols )), shape = (N , N ))
129+ rbf_regularized = rbf + lmbd * sparse .eye (N , format = "csr" ) # identity matrix gives a little extra height at the centers
130+ alpha = sparse .linalg .spsolve (rbf_regularized , x_flattened ) # solve sparse system targeting the noisy data,
131+ # can take matrix target, O(N sigma^2) for each vector
132+ x_hat_flattened = rbf @ alpha # find samples of reconstructions using the smooth bases
133+ dxdt_hat_flattened = drbfdt @ alpha
125134
126- return rbf @ alpha , drbfdt @ alpha # find samples of reconstructions using the smooth bases
135+ return np . moveaxis ( x_hat_flattened . reshape ( plump ), 0 , axis ), np . moveaxis ( dxdt_hat_flattened . reshape ( plump ), 0 , axis )
0 commit comments