|
127 | 127 | # Test dictionary format for merges with string IDs |
128 | 128 | curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} |
129 | 129 |
|
130 | | - |
131 | 130 | # This is a failure example with duplicated merge |
132 | 131 | duplicate_merge = curation_ids_int.copy() |
133 | 132 | duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] |
@@ -292,11 +291,117 @@ def test_apply_curation_with_split(): |
292 | 291 | assert analyzer_curated.sorting.get_property("pyramidal", ids=[unit_id])[0] |
293 | 292 |
|
294 | 293 |
|
| 294 | +def test_apply_curation_with_split_multi_segment(): |
| 295 | + recording, sorting = generate_ground_truth_recording(durations=[10.0, 10.0], num_units=9, seed=2205) |
| 296 | + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) |
| 297 | + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) |
| 298 | + num_segments = sorting.get_num_segments() |
| 299 | + |
| 300 | + curation_with_splits_multi_segment = curation_with_splits.copy() |
| 301 | + |
| 302 | + # we make a split so that each subsplit will have all spikes from different segments |
| 303 | + split_unit_id = curation_with_splits_multi_segment["splits"][0]["unit_id"] |
| 304 | + sv = sorting.to_spike_vector() |
| 305 | + unit_index = sorting.id_to_index(split_unit_id) |
| 306 | + spikes_from_split_unit = sv[sv["unit_index"] == unit_index] |
| 307 | + |
| 308 | + split_indices = [] |
| 309 | + cum_spikes = 0 |
| 310 | + for segment_index in range(num_segments): |
| 311 | + spikes_in_segment = spikes_from_split_unit[spikes_from_split_unit["segment_index"] == segment_index] |
| 312 | + split_indices.append(np.arange(0, len(spikes_in_segment)) + cum_spikes) |
| 313 | + cum_spikes += len(spikes_in_segment) |
| 314 | + |
| 315 | + curation_with_splits_multi_segment["splits"][0]["split_indices"] = split_indices |
| 316 | + |
| 317 | + sorting_curated = apply_curation(sorting, curation_with_splits_multi_segment) |
| 318 | + |
| 319 | + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 1 |
| 320 | + assert 2 not in sorting_curated.unit_ids |
| 321 | + assert 43 in sorting_curated.unit_ids |
| 322 | + assert 44 in sorting_curated.unit_ids |
| 323 | + |
| 324 | + # check that spike trains are correctly split across segments |
| 325 | + for seg_index in range(num_segments): |
| 326 | + st_43 = sorting_curated.get_unit_spike_train(43, segment_index=seg_index) |
| 327 | + st_44 = sorting_curated.get_unit_spike_train(44, segment_index=seg_index) |
| 328 | + if seg_index == 0: |
| 329 | + assert len(st_43) > 0 |
| 330 | + assert len(st_44) == 0 |
| 331 | + else: |
| 332 | + assert len(st_43) == 0 |
| 333 | + assert len(st_44) > 0 |
| 334 | + |
| 335 | + |
| 336 | +def test_apply_curation_splits_with_mask(): |
| 337 | + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) |
| 338 | + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) |
| 339 | + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) |
| 340 | + |
| 341 | + # Get number of spikes for unit 2 |
| 342 | + num_spikes = sorting.count_num_spikes_per_unit()[2] |
| 343 | + |
| 344 | + # Create split labels that assign spikes to 3 different clusters |
| 345 | + split_labels = np.zeros(num_spikes, dtype=int) |
| 346 | + split_labels[: num_spikes // 3] = 0 # First third to cluster 0 |
| 347 | + split_labels[num_spikes // 3 : 2 * num_spikes // 3] = 1 # Second third to cluster 1 |
| 348 | + split_labels[2 * num_spikes // 3 :] = 2 # Last third to cluster 2 |
| 349 | + |
| 350 | + curation_with_mask_split = { |
| 351 | + "format_version": "2", |
| 352 | + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], |
| 353 | + "label_definitions": { |
| 354 | + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, |
| 355 | + "putative_type": { |
| 356 | + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], |
| 357 | + "exclusive": False, |
| 358 | + }, |
| 359 | + }, |
| 360 | + "manual_labels": [ |
| 361 | + {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]}, |
| 362 | + ], |
| 363 | + "splits": [ |
| 364 | + { |
| 365 | + "unit_id": 2, |
| 366 | + "split_mode": "labels", |
| 367 | + "split_labels": split_labels.tolist(), |
| 368 | + "split_new_unit_ids": [43, 44, 45], |
| 369 | + } |
| 370 | + ], |
| 371 | + } |
| 372 | + |
| 373 | + sorting_curated = apply_curation(sorting, curation_with_mask_split) |
| 374 | + |
| 375 | + # Check results |
| 376 | + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 # Original units - 1 (split) + 3 (new) |
| 377 | + assert 2 not in sorting_curated.unit_ids # Original unit should be removed |
| 378 | + |
| 379 | + # Check new split units |
| 380 | + split_unit_ids = [43, 44, 45] |
| 381 | + for unit_id in split_unit_ids: |
| 382 | + assert unit_id in sorting_curated.unit_ids |
| 383 | + # Check properties are propagated |
| 384 | + assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good" |
| 385 | + assert sorting_curated.get_property("excitatory", ids=[unit_id])[0] |
| 386 | + assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0] |
| 387 | + |
| 388 | + # Check analyzer |
| 389 | + analyzer_curated = apply_curation(analyzer, curation_with_mask_split) |
| 390 | + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 |
| 391 | + |
| 392 | + # Verify split sizes |
| 393 | + spike_counts = analyzer_curated.sorting.count_num_spikes_per_unit() |
| 394 | + assert spike_counts[43] == num_spikes // 3 # First third |
| 395 | + assert spike_counts[44] == num_spikes // 3 # Second third |
| 396 | + assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3) # Remainder |
| 397 | + |
| 398 | + |
295 | 399 | if __name__ == "__main__": |
296 | | - # test_curation_format_validation() |
297 | | - # test_to_from_json() |
298 | | - # test_convert_from_sortingview_curation_format_v0() |
299 | | - # test_curation_label_to_vectors() |
300 | | - # test_curation_label_to_dataframe() |
301 | | - # test_apply_curation() |
302 | | - test_apply_curation_with_split() |
| 400 | + test_curation_format_validation() |
| 401 | + test_to_from_json() |
| 402 | + test_convert_from_sortingview_curation_format_v0() |
| 403 | + test_curation_label_to_vectors() |
| 404 | + test_curation_label_to_dataframe() |
| 405 | + test_apply_curation() |
| 406 | + test_apply_curation_with_split_multi_segment() |
| 407 | + test_apply_curation_splits_with_mask() |
0 commit comments