Skip to content

Commit 659ecff

Browse files
committed
Extend splitting-tests to multi-segment and mask labels
1 parent d4fa8bf commit 659ecff

2 files changed

Lines changed: 116 additions & 10 deletions

File tree

src/spikeinterface/core/sorting_tools.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,17 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr
469469
spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
470470
spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
471471

472-
# split_indices are a concatenation across segments
473472
for unit_id in sorting.unit_ids:
474473
if unit_id in unit_splits:
475474
split_indices = unit_splits[unit_id]
476475
new_split_ids = new_unit_ids[list(unit_splits.keys()).index(unit_id)]
477476

478477
for split, new_unit_id in zip(split_indices, new_split_ids):
479478
new_unit_index = all_unit_ids.index(new_unit_id)
479+
# split_indices are a concatenation across segments with absolute indices
480+
# so we need to concatenate the spike indices across segments
480481
spike_indices_unit = np.concatenate(
481-
spike_indices[segment_index][unit_id] for segment_index in range(num_seg)
482+
[spike_indices[segment_index][unit_id] for segment_index in range(num_seg)]
482483
)
483484
spikes["unit_index"][spike_indices_unit[split]] = new_unit_index
484485
else:

src/spikeinterface/curation/tests/test_curation_format.py

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,6 @@
127127
# Test dictionary format for merges with string IDs
128128
curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}}
129129

130-
131130
# This is a failure example with duplicated merge
132131
duplicate_merge = curation_ids_int.copy()
133132
duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]]
@@ -292,11 +291,117 @@ def test_apply_curation_with_split():
292291
assert analyzer_curated.sorting.get_property("pyramidal", ids=[unit_id])[0]
293292

294293

294+
def test_apply_curation_with_split_multi_segment():
    """Apply an index-based split to a two-segment sorting and check the result per segment.

    The split of one unit is built so that the first sub-unit receives every
    spike of segment 0 and the second sub-unit every spike of segment 1.
    The indices are positions within the unit's spikes concatenated across
    segments, hence the running offset when building them.
    """
    import copy

    recording, sorting = generate_ground_truth_recording(durations=[10.0, 10.0], num_units=9, seed=2205)
    sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]))
    # NOTE(review): the analyzer is built but not used by the assertions below.
    analyzer = create_sorting_analyzer(sorting, recording, sparse=False)
    num_segments = sorting.get_num_segments()

    # Deep copy: a shallow ``dict.copy()`` shares the nested "splits" list, so
    # overwriting "split_indices" below would silently modify the module-level
    # ``curation_with_splits`` fixture used by other tests.
    curation_with_splits_multi_segment = copy.deepcopy(curation_with_splits)

    # Build one sub-split per segment, each holding all spikes of that segment.
    split_unit_id = curation_with_splits_multi_segment["splits"][0]["unit_id"]
    sv = sorting.to_spike_vector()
    unit_index = sorting.id_to_index(split_unit_id)
    spikes_from_split_unit = sv[sv["unit_index"] == unit_index]

    split_indices = []
    cum_spikes = 0
    for segment_index in range(num_segments):
        spikes_in_segment = spikes_from_split_unit[spikes_from_split_unit["segment_index"] == segment_index]
        # Offset by the spikes already consumed in previous segments.
        split_indices.append(np.arange(0, len(spikes_in_segment)) + cum_spikes)
        cum_spikes += len(spikes_in_segment)

    curation_with_splits_multi_segment["splits"][0]["split_indices"] = split_indices

    sorting_curated = apply_curation(sorting, curation_with_splits_multi_segment)

    # One unit is replaced by two, so the total grows by one.
    assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 1
    assert 2 not in sorting_curated.unit_ids
    assert 43 in sorting_curated.unit_ids
    assert 44 in sorting_curated.unit_ids

    # Unit 43 must only fire in segment 0 and unit 44 only in the other segment.
    for seg_index in range(num_segments):
        st_43 = sorting_curated.get_unit_spike_train(43, segment_index=seg_index)
        st_44 = sorting_curated.get_unit_spike_train(44, segment_index=seg_index)
        if seg_index == 0:
            assert len(st_43) > 0
            assert len(st_44) == 0
        else:
            assert len(st_43) == 0
            assert len(st_44) > 0
335+
336+
def test_apply_curation_splits_with_mask():
    """Split unit 2 into three units using per-spike labels and check sorting and analyzer.

    The labels assign the first third of the unit's spikes to label 0, the
    second third to label 1 and the remainder to label 2. The same ``third``
    value drives both the label boundaries and the expected spike counts, so
    the size assertions hold for any spike count.
    """
    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205)
    sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]))
    analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

    # Get number of spikes for unit 2
    num_spikes = sorting.count_num_spikes_per_unit()[2]

    # Boundaries are multiples of ``third``. The previous ``2 * num_spikes // 3``
    # parses as ``(2 * num_spikes) // 3``, which differs from ``2 * third``
    # whenever ``num_spikes % 3 == 2`` and then contradicts the assertions below.
    third = num_spikes // 3
    split_labels = np.zeros(num_spikes, dtype=int)
    split_labels[:third] = 0  # First third to cluster 0
    split_labels[third : 2 * third] = 1  # Second third to cluster 1
    split_labels[2 * third :] = 2  # Remainder to cluster 2

    curation_with_mask_split = {
        "format_version": "2",
        "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42],
        "label_definitions": {
            "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True},
            "putative_type": {
                "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"],
                "exclusive": False,
            },
        },
        "manual_labels": [
            {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]},
        ],
        "splits": [
            {
                "unit_id": 2,
                "split_mode": "labels",
                "split_labels": split_labels.tolist(),
                "split_new_unit_ids": [43, 44, 45],
            }
        ],
    }

    sorting_curated = apply_curation(sorting, curation_with_mask_split)

    # Check results
    assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2  # Original units - 1 (split) + 3 (new)
    assert 2 not in sorting_curated.unit_ids  # Original unit should be removed

    # Check new split units
    split_unit_ids = [43, 44, 45]
    for unit_id in split_unit_ids:
        assert unit_id in sorting_curated.unit_ids
        # Check properties are propagated
        assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good"
        assert sorting_curated.get_property("excitatory", ids=[unit_id])[0]
        assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0]

    # Check analyzer
    analyzer_curated = apply_curation(analyzer, curation_with_mask_split)
    assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2

    # Verify split sizes against the boundaries used to build the labels
    spike_counts = analyzer_curated.sorting.count_num_spikes_per_unit()
    assert spike_counts[43] == third  # First third
    assert spike_counts[44] == third  # Second third
    assert spike_counts[45] == num_spikes - 2 * third  # Remainder
398+
295399
if __name__ == "__main__":
    # Manual runner: execute every test of this module in definition order.
    test_curation_format_validation()
    test_to_from_json()
    test_convert_from_sortingview_curation_format_v0()
    test_curation_label_to_vectors()
    test_curation_label_to_dataframe()
    test_apply_curation()
    # Previously the only active call; keep it so the manual run still covers it.
    test_apply_curation_with_split()
    test_apply_curation_with_split_multi_segment()
    test_apply_curation_splits_with_mask()

0 commit comments

Comments
 (0)