Skip to content

Commit f1f3acb

Browse files
Improve stability of velocity fits in template metrics (#4342)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dd063c6 commit f1f3acb

1 file changed

Lines changed: 25 additions & 15 deletions

File tree

src/spikeinterface/metrics/template/metrics.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -284,21 +284,27 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
284284
return template[:, sort_indices], channel_locations[sort_indices, :]
285285

286286

287-
def fit_velocity(peak_times, channel_dist):
287+
def fit_line_robust(x, y, eps=1e-12):
288288
"""
289-
Fit velocity from peak times and channel distances using robust Theilsen estimator.
289+
Fit line using robust Theil-Sen estimator (median of pairwise slopes).
290290
"""
291-
# from scipy.stats import linregress
292-
# slope, intercept, _, _, _ = linregress(peak_times, channel_dist)
291+
import itertools
293292

294-
from sklearn.linear_model import TheilSenRegressor
293+
# Calculate slope and bias using Theil-Sen estimator
294+
slopes = []
295+
for (x0, y0), (x1, y1) in itertools.combinations(zip(x, y), 2):
296+
if np.abs(x1 - x0) > eps:
297+
slopes.append((y1 - y0) / (x1 - x0))
298+
if len(slopes) == 0: # all x are identical
299+
return np.nan, -np.inf
300+
slope = np.median(slopes)
301+
bias = np.median(y - slope * x)
295302

296-
theil = TheilSenRegressor()
297-
theil.fit(peak_times.reshape(-1, 1), channel_dist)
298-
slope = theil.coef_[0]
299-
intercept = theil.intercept_
300-
score = theil.score(peak_times.reshape(-1, 1), channel_dist)
301-
return slope, intercept, score
303+
# Calculate R2 score
304+
y_pred = slope * x + bias
305+
r2_score = 1 - ((y - y_pred) ** 2).sum() / (((y - y.mean()) ** 2).sum() + eps)
306+
307+
return slope, r2_score
302308

303309

304310
def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs):
@@ -354,8 +360,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
354360
channel_locations_above = channel_locations[channels_above]
355361
peak_times_ms_above = np.argmin(template_above, 0) / sampling_frequency * 1000 - max_peak_time
356362
distances_um_above = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_above])
357-
velocity_above, _, score = fit_velocity(peak_times_ms_above, distances_um_above)
358-
if score < min_r2:
363+
inv_velocity_above, score = fit_line_robust(distances_um_above, peak_times_ms_above)
364+
if score > min_r2 and inv_velocity_above != 0:
365+
velocity_above = 1 / inv_velocity_above
366+
else:
359367
velocity_above = np.nan
360368

361369
# Compute velocity below
@@ -367,8 +375,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
367375
channel_locations_below = channel_locations[channels_below]
368376
peak_times_ms_below = np.argmin(template_below, 0) / sampling_frequency * 1000 - max_peak_time
369377
distances_um_below = np.array([np.linalg.norm(cl - max_channel_location) for cl in channel_locations_below])
370-
velocity_below, _, score = fit_velocity(peak_times_ms_below, distances_um_below)
371-
if score < min_r2:
378+
inv_velocity_below, score = fit_line_robust(distances_um_below, peak_times_ms_below)
379+
if score > min_r2 and inv_velocity_below != 0:
380+
velocity_below = 1 / inv_velocity_below
381+
else:
372382
velocity_below = np.nan
373383

374384
return velocity_above, velocity_below

0 commit comments

Comments
 (0)