@@ -26,7 +26,12 @@ def test_common_reference(recording):
2626 rec_cmr_ref = common_reference (recording , reference = "global" , operator = "median" , ref_channel_ids = ["a" , "b" , "c" ])
2727 rec_car = common_reference (recording , reference = "global" , operator = "average" )
2828 rec_sin = common_reference (recording , reference = "single" , ref_channel_ids = ["a" ])
29- rec_local_car = common_reference (recording , reference = "local" , local_radius = (20 , 65 ), operator = "median" )
29+ rec_local_cmr = common_reference (
30+ recording , reference = "local" , local_radius = (25 , 65 ), operator = "median" , min_local_neighbors = 1
31+ )
32+ rec_local_car = common_reference (
33+ recording , reference = "local" , local_radius = (25 , 65 ), operator = "average" , min_local_neighbors = 1
34+ )
3035
3136 traces = recording .get_traces ()
3237 assert np .allclose (traces , rec_cmr .get_traces () + np .median (traces , axis = 1 , keepdims = True ), atol = 0.01 )
@@ -35,21 +40,26 @@ def test_common_reference(recording):
3540 assert not np .all (rec_sin .get_traces ()[0 ])
3641 assert np .allclose (rec_sin .get_traces ()[:, 1 ], traces [:, 1 ] - traces [:, 0 ])
3742
38- assert np .allclose (traces [:, 0 ], rec_local_car .get_traces ()[:, 0 ] + np .median (traces [:, [2 , 3 ]], axis = 1 ), atol = 0.01 )
39- assert np .allclose (traces [:, 1 ], rec_local_car .get_traces ()[:, 1 ] + np .median (traces [:, [3 ]], axis = 1 ), atol = 0.01 )
43+ assert np .allclose (traces [:, 0 ], rec_local_cmr .get_traces ()[:, 0 ] + np .median (traces [:, [2 , 3 ]], axis = 1 ), atol = 0.01 )
44+ assert np .allclose (traces [:, 1 ], rec_local_cmr .get_traces ()[:, 1 ] + np .median (traces [:, [3 ]], axis = 1 ), atol = 0.01 )
45+
46+ assert np .allclose (traces [:, 0 ], rec_local_car .get_traces ()[:, 0 ] + np .mean (traces [:, [2 , 3 ]], axis = 1 ), atol = 0.01 )
47+ assert np .allclose (traces [:, 1 ], rec_local_car .get_traces ()[:, 1 ] + np .mean (traces [:, [3 ]], axis = 1 ), atol = 0.01 )
4048
4149 # Saving tests
4250 rec_cmr .save (verbose = False )
4351 rec_car .save (verbose = False )
4452 rec_sin .save (verbose = False )
53+ rec_local_cmr .save (verbose = False )
4554 rec_local_car .save (verbose = False )
4655
4756
4857def test_common_reference_channel_slicing (recording ):
4958 recording_cmr = common_reference (recording , reference = "global" , operator = "median" )
5059 recording_car = common_reference (recording , reference = "global" , operator = "average" )
5160 recording_single_reference = common_reference (recording , reference = "single" , ref_channel_ids = ["b" ])
52- recording_local_car = common_reference (recording , reference = "local" , local_radius = (20 , 65 ), operator = "median" )
61+ recording_local_cmr = common_reference (recording , reference = "local" , local_radius = (20 , 65 ), operator = "median" )
62+ recording_local_car = common_reference (recording , reference = "local" , local_radius = (20 , 65 ), operator = "average" )
5363
5464 channel_ids = ["b" , "d" ]
5565 indices = recording .ids_to_indices (channel_ids )
@@ -73,9 +83,12 @@ def test_common_reference_channel_slicing(recording):
7383
7484 assert np .allclose (single_reference_trace , expected_trace , atol = 0.01 )
7585
86+ local_trace = recording_local_cmr .get_traces (channel_ids = all_channel_ids )
87+ local_trace_sub = recording_local_cmr .get_traces (channel_ids = channel_ids )
88+ assert np .all (local_trace [:, indices ] == local_trace_sub )
89+
7690 local_trace = recording_local_car .get_traces (channel_ids = all_channel_ids )
7791 local_trace_sub = recording_local_car .get_traces (channel_ids = channel_ids )
78-
7992 assert np .all (local_trace [:, indices ] == local_trace_sub )
8093
8194 # test segment slicing
@@ -157,8 +170,44 @@ def test_common_reference_groups(recording):
157170 assert np .allclose (traces [:, 1 ], 0 )
158171
159172
173+ def test_min_local_radius ():
174+ # Test that local radius smaller than the number of channels is handled correctly
175+ recording = generate_recording (durations = [1.0 ], num_channels = 32 )
176+ # remove closest channel to first channel
177+ recording = recording .remove_channels (recording .channel_ids [1 :5 ])
178+ with pytest .warns (UserWarning ):
179+ recording_local_car = common_reference (
180+ recording , reference = "local" , local_radius = (60 , 150 ), operator = "average" , min_local_neighbors = 5
181+ )
182+
183+
184+ @pytest .mark .skip (reason = "This test can be used to check local CAR vs local CMR performance" )
185+ def test_local_car_vs_cmr_performance ():
186+ import time
187+
188+ # Test that local CAR is faster than local CMR when there are many channels
189+ recording = generate_recording (durations = [10.0 ], num_channels = 384 )
190+
191+ rec_local_cmr = common_reference (
192+ recording , reference = "local" , local_radius = (20 , 65 ), operator = "median" , min_local_neighbors = 1
193+ )
194+ t_start_cmr = time .perf_counter ()
195+ rec_local_cmr .get_traces ()
196+ t_end_cmr = time .perf_counter ()
197+ cmr_time = t_end_cmr - t_start_cmr
198+
199+ rec_local_car = common_reference (
200+ recording , reference = "local" , local_radius = (20 , 65 ), operator = "average" , min_local_neighbors = 1
201+ )
202+ t_start_car = time .perf_counter ()
203+ rec_local_car .get_traces ()
204+ t_end_car = time .perf_counter ()
205+ car_time = t_end_car - t_start_car
206+
207+ print (f"Local CMR time: { cmr_time :.4f} seconds" )
208+ print (f"Local CAR time: { car_time :.4f} seconds" )
209+ assert car_time < cmr_time
210+
211+
160212if __name__ == "__main__" :
161- recording = _generate_test_recording ()
162- test_common_reference (recording )
163- test_common_reference_channel_slicing (recording )
164- test_common_reference_groups (recording )
213+ test_local_car_vs_cmr_performance ()
0 commit comments