Skip to content

Commit cb7fb5b

Browse files
committed
more efficient slide function
1 parent e03f5f8 commit cb7fb5b

1 file changed

Lines changed: 17 additions & 16 deletions

File tree

pynumdiff/utils/utility.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,26 +181,27 @@ def slide_function(func, x, dt, kernel, *args, stride=1, pass_weights=False, **k
181181
if len(kernel) % 2 == 0: raise ValueError("Kernel window size should be odd.")
182182
half_window_size = (len(kernel) - 1)//2 # int because len(kernel) is always odd
183183

184-
weights = np.zeros((int(np.ceil(len(x)/stride)), len(x))) # Could be more space efficient
185-
x_hats = np.zeros(weights.shape)
186-
dxdt_hats = np.zeros(weights.shape)
184+
x_hat = np.zeros(x.shape)
185+
dxdt_hat = np.zeros(x.shape)
186+
weight_sum = np.zeros(x.shape)
187187

188188
for i,midpoint in enumerate(range(0, len(x), stride)): # iterate window midpoints
189189
# find where to index data and kernel, taking care at edges
190-
window = slice(max(0, midpoint - half_window_size),
191-
min(len(x), midpoint + half_window_size + 1)) # +1 because slicing is exclusive of end
192-
kslice = slice(max(0, half_window_size - midpoint),
193-
min(len(kernel), len(kernel) - (midpoint + half_window_size + 1 - len(x))))
190+
start = max(0, midpoint - half_window_size)
191+
end = min(len(x), midpoint + half_window_size + 1) # +1 because slicing is exclusive of end
192+
window = slice(start, end)
194193

195-
# weights need to be renormalized if running off an edge
196-
weights[i, window] = kernel if kslice.stop - kslice.stop == len(kernel) else kernel[kslice]/np.sum(kernel[kslice])
197-
if pass_weights: kwargs['weights'] = weights[i, window]
194+
kstart = max(0, half_window_size - midpoint)
195+
kend = kstart + (end - start)
196+
kslice = slice(kstart, kend)
198197

199-
# run the function on the window and save results
200-
x_hats[i,window], dxdt_hats[i,window] = func(x[window], dt, *args, **kwargs)
198+
w = kernel if (end-start) == len(kernel) else kernel[kslice]/np.sum(kernel[kslice])
199+
if pass_weights: kwargs['weights'] = w
201200

202-
weights /= weights.sum(axis=0, keepdims=True) # normalize the weights
203-
x_hat = np.sum(weights*x_hats, axis=0)
204-
dxdt_hat = np.sum(weights*dxdt_hats, axis=0)
201+
# run the function on the window and add weighted results to cumulative answers
202+
x_window_hat, dxdt_window_hat = func(x[window], dt, *args, **kwargs)
203+
x_hat[window] += w * x_window_hat
204+
dxdt_hat[window] += w * dxdt_window_hat
205+
weight_sum[window] += w # save sum of weights for normalization at the end
205206

206-
return x_hat, dxdt_hat
207+
return x_hat/weight_sum, dxdt_hat/weight_sum

0 commit comments

Comments
 (0)