@@ -284,21 +284,27 @@ def sort_template_and_locations(template, channel_locations, depth_direction="y"
284284 return template [:, sort_indices ], channel_locations [sort_indices , :]
285285
286286
287- def fit_velocity ( peak_times , channel_dist ):
287+ def fit_line_robust ( x , y , eps = 1e-12 ):
288288 """
289- Fit velocity from peak times and channel distances using robust Theilsen estimator .
289+ Fit line using robust Theil-Sen estimator (median of pairwise slopes) .
290290 """
291- # from scipy.stats import linregress
292- # slope, intercept, _, _, _ = linregress(peak_times, channel_dist)
291+ import itertools
293292
294- from sklearn .linear_model import TheilSenRegressor
293+ # Calculate slope and bias using Theil-Sen estimator
294+ slopes = []
295+ for (x0 , y0 ), (x1 , y1 ) in itertools .combinations (zip (x , y ), 2 ):
296+ if np .abs (x1 - x0 ) > eps :
297+ slopes .append ((y1 - y0 ) / (x1 - x0 ))
298+ if len (slopes ) == 0 : # all x are identical
299+ return np .nan , - np .inf
300+ slope = np .median (slopes )
301+ bias = np .median (y - slope * x )
295302
296- theil = TheilSenRegressor ()
297- theil .fit (peak_times .reshape (- 1 , 1 ), channel_dist )
298- slope = theil .coef_ [0 ]
299- intercept = theil .intercept_
300- score = theil .score (peak_times .reshape (- 1 , 1 ), channel_dist )
301- return slope , intercept , score
303+ # Calculate R2 score
304+ y_pred = slope * x + bias
305+ r2_score = 1 - ((y - y_pred ) ** 2 ).sum () / (((y - y .mean ()) ** 2 ).sum () + eps )
306+
307+ return slope , r2_score
302308
303309
304310def get_velocity_fits (template , channel_locations , sampling_frequency , ** kwargs ):
@@ -354,8 +360,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
354360 channel_locations_above = channel_locations [channels_above ]
355361 peak_times_ms_above = np .argmin (template_above , 0 ) / sampling_frequency * 1000 - max_peak_time
356362 distances_um_above = np .array ([np .linalg .norm (cl - max_channel_location ) for cl in channel_locations_above ])
357- velocity_above , _ , score = fit_velocity (peak_times_ms_above , distances_um_above )
358- if score < min_r2 :
363+ inv_velocity_above , score = fit_line_robust (distances_um_above , peak_times_ms_above )
364+ if score > min_r2 and inv_velocity_above != 0 :
365+ velocity_above = 1 / inv_velocity_above
366+ else :
359367 velocity_above = np .nan
360368
361369 # Compute velocity below
@@ -367,8 +375,10 @@ def get_velocity_fits(template, channel_locations, sampling_frequency, **kwargs)
367375 channel_locations_below = channel_locations [channels_below ]
368376 peak_times_ms_below = np .argmin (template_below , 0 ) / sampling_frequency * 1000 - max_peak_time
369377 distances_um_below = np .array ([np .linalg .norm (cl - max_channel_location ) for cl in channel_locations_below ])
370- velocity_below , _ , score = fit_velocity (peak_times_ms_below , distances_um_below )
371- if score < min_r2 :
378+ inv_velocity_below , score = fit_line_robust (distances_um_below , peak_times_ms_below )
379+ if score > min_r2 and inv_velocity_below != 0 :
380+ velocity_below = 1 / inv_velocity_below
381+ else :
372382 velocity_below = np .nan
373383
374384 return velocity_above , velocity_below
0 commit comments