@@ -74,7 +74,7 @@ def spectraldiff(x, dt, params=None, options=None, high_freq_cutoff=None, even_e
7474 return x_hat , dxdt_hat
7575
7676
77- def rbfdiff (x , dt_or_t , sigma = 1 , lmbd = 0.01 ):
77+ def rbfdiff (x , dt_or_t , sigma = 1 , lmbd = 0.01 , axis = 0 ):
7878 """Find smoothed function and derivative estimates by fitting noisy data with radial-basis-functions. Naively,
7979 fill a matrix with basis function samples and solve a linear inverse problem against the data, but truncate tiny
8080 values to make columns sparse. Each basis function "hill" is topped with a "tower" of height :code:`lmbd` to reach
@@ -85,14 +85,21 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
8585 :math:`\\ Delta t` if given as a single float, or data locations if given as an array of same length as :code:`x`.
8686 :param float sigma: controls width of radial basis functions
8787 :param float lmbd: controls smoothness
88+ :param int axis: data dimension along which differentiation is performed
8889
8990 :return: - **x_hat** (np.array) -- estimated (smoothed) x
9091 - **dxdt_hat** (np.array) -- estimated derivative of x
9192 """
93+ x = np .asarray (x )
94+ x = np .moveaxis (x , axis , 0 ) # bring target axis to front
95+ orig_shape = x .shape
96+ N = orig_shape [0 ]
97+ x_2d = x .reshape (N , - 1 ) # (N, M) — build matrix once, solve for all M columns
98+
9299 if np .isscalar (dt_or_t ):
93- t = np .arange (len ( x ) )* dt_or_t
100+ t = np .arange (N )* dt_or_t
94101 else : # support variable step size for this function
95- if len ( x ) != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
102+ if N != len (dt_or_t ): raise ValueError ("If `dt_or_t` is given as array-like, must have same length as `x`." )
96103 t = dt_or_t
97104
98105 # The below does the approximate equivalent of this code, but sparsely in O(N sigma^2), since the rbf falls off rapidly
@@ -115,9 +122,11 @@ def rbfdiff(x, dt_or_t, sigma=1, lmbd=0.01):
115122 dv = - radius / sigma ** 2 * v # take derivative of radial basis function, because d/dt coef*f(t) = coef*df/dt
116123 rows .append (n ); cols .append (j ); vals .append (v ); dvals .append (dv )
117124
118- rbf = sparse .csr_matrix ((vals , (rows , cols )), shape = (len ( t ), len ( t ) )) # Build sparse kernels, O(N sigma) entries
119- drbfdt = sparse .csr_matrix ((dvals , (rows , cols )), shape = (len ( t ), len ( t ) ))
120- rbf_regularized = rbf + lmbd * sparse .eye (len ( t ) , format = "csr" ) # identity matrix gives a little extra height at the centers
121- alpha = sparse .linalg .spsolve (rbf_regularized , x ) # solve sparse system targeting the noisy data, O(N sigma^2)
125+ rbf = sparse .csr_matrix ((vals , (rows , cols )), shape = (N , N )) # Build sparse kernels, O(N sigma) entries
126+ drbfdt = sparse .csr_matrix ((dvals , (rows , cols )), shape = (N , N ))
127+ rbf_regularized = rbf + lmbd * sparse .eye (N , format = "csr" ) # identity matrix gives a little extra height at the centers
128+ alpha = sparse .linalg .spsolve (rbf_regularized , x_2d ) # solve sparse system targeting the noisy data, O(N sigma^2)
122129
123- return rbf @ alpha , drbfdt @ alpha # find samples of reconstructions using the smooth bases
130+ x_hat = np .moveaxis ((rbf @ alpha ).reshape (orig_shape ), 0 , axis ) # find samples of reconstructions using the smooth bases
131+ dxdt_hat = np .moveaxis ((drbfdt @ alpha ).reshape (orig_shape ), 0 , axis )
132+ return x_hat , dxdt_hat
0 commit comments