Skip to content

Commit 62d680b

Browse files
committed
added window_size as a parameter to jerk_sliding, because that makes more sense
1 parent 9bb86a6 commit 62d680b

3 files changed

Lines changed: 21 additions & 23 deletions

File tree

pynumdiff/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def store_plots(request):
1212
request.config.plots = defaultdict(lambda: pyplot.subplots(6, 2, figsize=(12,7))) # 6 is len(test_funcs_and_derivs)
1313

1414
def pytest_sessionfinish(session, exitstatus):
15+
if not hasattr(session.config, 'plots'): return
1516
for method,(fig,axes) in session.config.plots.items():
1617
axes[-1,-1].legend()
1718
fig.suptitle(method.__name__)

pynumdiff/tests/test_diff_methods.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
# Function aliases for testing cases where parameters change the behavior in a big way
1212
def iterated_first_order(*args, **kwargs): return first_order(*args, **kwargs)
1313

14+
dt = 0.1
1415
t = np.linspace(0, 3, 31) # sample locations, including the endpoint
1516
tt = np.linspace(0, 3) # full domain, for visualizing denser plots
16-
ttt = np.linspace(0, 3, 201) # for testing jerk_sliding, which requires more points
1717
np.random.seed(7) # for repeatability of the test, so we don't get random failures
1818
noise = 0.05*np.random.randn(*t.shape)
19-
long_noise = 0.05*np.random.randn(*ttt.shape)
2019

2120
# Analytic (function, derivative) pairs on which to test differentiation methods.
2221
test_funcs_and_derivs = [
@@ -53,7 +52,7 @@ def iterated_first_order(*args, **kwargs): return first_order(*args, **kwargs)
5352
(jerk, {'gamma':10}), (jerk, [10]),
5453
(iterative_velocity, {'num_iterations':5, 'gamma':0.05}), (iterative_velocity, [5, 0.05]),
5554
(smooth_acceleration, {'gamma':2, 'window_size':5}), (smooth_acceleration, [2, 5]),
56-
(jerk_sliding, {'gamma':1e2, 'solver':'CLARABEL'}), (jerk_sliding, [1e2], {'solver':'CLARABEL'})
55+
(jerk_sliding, {'gamma':1, 'window_size':15}), (jerk_sliding, [1], {'window_size':15})
5756
]
5857

5958
# All the testing methodology follows the exact same pattern; the only thing that changes is the
@@ -186,11 +185,11 @@ def iterated_first_order(*args, **kwargs): return first_order(*args, **kwargs)
186185
[(0, 0), (1, 0), (0, -1), (1, 0)],
187186
[(1, 1), (2, 2), (1, 1), (2, 2)],
188187
[(1, 1), (3, 3), (1, 1), (3, 3)]],
189-
jerk_sliding: [[(-13, -14), (-12, -13), (0, -1), (1, 0)],
190-
[(-12, -13), (-12, -12), (0, -1), (0, 0)],
191-
[(-13, -14), (-12, -13), (0, -1), (0, 0)],
192-
[(-1, -2), (0, 0), (0, -1), (1, 0)],
193-
[(0, 0), (2, 1), (0, 0), (2, 1)],
188+
jerk_sliding: [[(-15, -15), (-16, -17), (0, -1), (1, 0)],
189+
[(-14, -14), (-14, -14), (0, -1), (0, 0)],
190+
[(-14, -14), (-14, -14), (0, -1), (0, 0)],
191+
[(-1, -1), (0, 0), (0, -1), (0, 0)],
192+
[(0, 0), (2, 2), (0, 0), (2, 2)],
194193
[(1, 1), (3, 3), (1, 1), (3, 3)]]
195194
}
196195

@@ -210,16 +209,9 @@ def test_diff_method(diff_method_and_params, test_func_and_deriv, request): # re
210209
except: warn(f"Cannot import cvxpy, skipping {diff_method} test."); return
211210

212211
# sample the true function and make noisy samples, and sample true derivative
213-
if diff_method != jerk_sliding:
214-
x = f(t)
215-
x_noisy = x + noise
216-
dxdt = df(t)
217-
dt = t[1] - t[0]
218-
else: # different density for jerk_sliding
219-
x = f(ttt)
220-
x_noisy = x + long_noise
221-
dxdt = df(ttt)
222-
dt = ttt[1] - ttt[0]
212+
x = f(t)
213+
x_noisy = x + noise
214+
dxdt = df(t)
223215

224216
# differentiate without and with noise, accounting for new and old styles of calling functions
225217
x_hat, dxdt_hat = diff_method(x, dt, **params) if isinstance(params, dict) \

pynumdiff/total_variation_regularization/_total_variation_regularization.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def smooth_acceleration(x, dt, params=None, options=None, gamma=None, window_siz
229229
return x_hat, dxdt_hat
230230

231231

232-
def jerk_sliding(x, dt, params=None, options=None, gamma=None, solver=None):
232+
def jerk_sliding(x, dt, params=None, options=None, gamma=None, solver=None, window_size=101):
233233
"""Use convex optimization (cvxpy) to solve for the jerk total variation regularized derivative.
234234
235235
:param np.array[float] x: data to differentiate
@@ -239,6 +239,7 @@ def jerk_sliding(x, dt, params=None, options=None, gamma=None, solver=None):
239239
:param float gamma: the regularization parameter
240240
:param str solver: the solver CVXPY should use, 'MOSEK', 'CVXOPT', 'CLARABEL', 'ECOS', etc.
241241
In testing, 'MOSEK' was the most robust. If not given, fall back to CVXPY's default.
242+
:param int window_size: how wide to make the kernel
242243
243244
:return: tuple[np.array, np.array] of\n
244245
- **x_hat** -- estimated (smoothed) x
@@ -250,12 +251,16 @@ def jerk_sliding(x, dt, params=None, options=None, gamma=None, solver=None):
250251
gamma = params[0] if isinstance(params, list) else params
251252
if options != None:
252253
if 'solver' in options: solver = options['solver']
254+
if 'window_size' in options: window_size = options['window_size']
253255
elif gamma == None:
254256
raise ValueError("`gamma` must be given.")
255257

256-
if len(x) <= 100:
257-
warn("len(x) <= 1000, calling standard jerk() without sliding")
258+
if len(x) < window_size:
259+
warn("len(x) <= window_size, calling standard jerk() without sliding")
258260
return _total_variation_regularized_derivative(x, dt, 3, gamma, solver=solver)
259261

260-
kernel = np.hstack((np.arange(1, 21)/20, np.ones(60), (np.arange(0, 21)/20)[::-1]))
261-
return utility.slide_function(_total_variation_regularized_derivative, x, dt, kernel, 3, gamma, stride=20, solver=solver)
262+
ramp = window_size//5
263+
kernel = np.hstack((np.arange(1, ramp+1)/ramp, np.ones(window_size - 2*ramp), (np.arange(1, ramp+1)/ramp)[::-1]))
264+
return utility.slide_function(_total_variation_regularized_derivative, x, dt, kernel, 3, gamma, stride=ramp, solver=solver)
265+
266+

0 commit comments

Comments
 (0)