@@ -286,9 +286,8 @@ def test_sorting_analyzer_get_durations_from_recording(self, time_vector_recordi
286286 """
287287 _ , times_recording , _ = time_vector_recording
288288
289- sorting = si .generate_sorting (
290- durations = [times_recording .get_duration (s ) for s in range (times_recording .get_num_segments ())]
291- )
289+ durations = [times_recording .get_duration (s ) for s in range (times_recording .get_num_segments ())]
290+ sorting = si .generate_sorting (durations = durations )
292291 sorting_analyzer = si .create_sorting_analyzer (sorting , recording = times_recording )
293292
294293 assert np .array_equal (sorting_analyzer .get_total_duration (), times_recording .get_total_duration ())
@@ -484,10 +483,51 @@ def test_get_end_time_is_last_spike(self):
484483 assert sorting .get_end_time (segment_index = 0 ) == expected_time
485484
486485 def test_get_start_time_with_t_start (self ):
487- sorting = generate_sorting (num_units = 5 , durations = [10 ])
488- sorting .segments [0 ]._t_start = 100.0
486+ sorting = generate_sorting (num_units = 5 , durations = [10 ], t_starts = [100.0 ])
489487 assert sorting .get_start_time (segment_index = 0 ) == 100.0
490488
489+ def test_shift_times (self ):
490+ sorting = generate_sorting (num_units = 5 , durations = [10 ])
491+ unit_id = sorting .unit_ids [0 ]
492+
493+ spike_times_before = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
494+
495+ sorting .shift_times (shift = 5.0 )
496+
497+ assert sorting .get_start_time (segment_index = 0 ) == 5.0
498+ spike_times_after = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
499+ assert np .allclose (spike_times_after , spike_times_before + 5.0 )
500+
501+ def test_shift_times_all_segments (self ):
502+ sorting = generate_sorting (num_units = 5 , durations = [10 , 15 ], t_starts = [1.0 , 2.0 ])
503+
504+ sorting .shift_times (shift = 3.0 )
505+
506+ assert sorting .get_start_time (segment_index = 0 ) == 4.0
507+ assert sorting .get_start_time (segment_index = 1 ) == 5.0
508+
509+ def test_shift_times_single_segment (self ):
510+ sorting = generate_sorting (num_units = 5 , durations = [10 , 15 ], t_starts = [1.0 , 2.0 ])
511+
512+ sorting .shift_times (shift = 3.0 , segment_index = 1 )
513+
514+ assert sorting .get_start_time (segment_index = 0 ) == 1.0
515+ assert sorting .get_start_time (segment_index = 1 ) == 5.0
516+
517+ def test_shift_times_with_native_spike_times (self ):
518+ """Shift must apply even when the segment provides native spike times (e.g. NWB extractors)."""
519+ sorting = generate_sorting (num_units = 5 , durations = [10 ])
520+ unit_id = sorting .unit_ids [0 ]
521+ segment = sorting .segments [0 ]
522+
523+ # Simulate a segment that provides native spike times directly
524+ original_times = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True ).copy ()
525+ segment .get_unit_spike_train_in_seconds = lambda unit_id , start_time , end_time : original_times
526+
527+ sorting .shift_times (shift = 5.0 )
528+ spike_times = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
529+ assert np .allclose (spike_times , original_times + 5.0 )
530+
491531
492532class TestSortingTimeWithRecording :
493533 """
@@ -504,17 +544,16 @@ def test_get_start_end_time(self):
504544 assert sorting .get_end_time (segment_index = 0 ) == recording .get_end_time (segment_index = 0 )
505545
506546 def test_register_recording_copies_start_times (self ):
507- """Registering a recording copies its start times into the sorting segments."""
508- sorting = generate_sorting (num_units = 5 , durations = [10 ])
509- sorting .segments [0 ]._t_start = 100.0
547+ """Registering a recording overrides any pre-existing sorting start time."""
548+ sorting = generate_sorting (num_units = 5 , durations = [10 ], t_starts = [100.0 ])
510549
511550 recording = generate_recording (num_channels = 4 , durations = [10 ])
512551 recording .shift_times (shift = 50.0 )
513552 sorting .register_recording (recording )
514553
515- # _t_start now mirrors the recording's start time, preserving it across
516- # save/load cycles even when the recording is not attached .
517- assert sorting .segments [ 0 ]. _t_start == recording .get_start_time (segment_index = 0 )
554+ # The sorting's start time now mirrors the recording's start time, preserving it
555+ # across save/load cycles even when the recording is later detached .
556+ assert sorting .get_start_time ( segment_index = 0 ) == recording .get_start_time (segment_index = 0 )
518557 assert sorting .get_start_time (segment_index = 0 ) == 50.0
519558
520559 def test_with_recording_shifted_start (self ):
@@ -526,3 +565,68 @@ def test_with_recording_shifted_start(self):
526565 sorting .register_recording (recording )
527566
528567 assert sorting .get_start_time (segment_index = 0 ) == 50.0
568+
569+ def test_shift_times (self ):
570+ recording = generate_recording (num_channels = 4 , durations = [10 ])
571+ sorting = generate_sorting (num_units = 5 , durations = [10 ])
572+ sorting .register_recording (recording )
573+ unit_id = sorting .unit_ids [0 ]
574+
575+ rec_start_before = recording .get_start_time (segment_index = 0 )
576+ rec_end_before = recording .get_end_time (segment_index = 0 )
577+ spike_times_before = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
578+
579+ sorting .shift_times (shift = 5.0 )
580+
581+ # The recording should be untouched
582+ assert recording .get_start_time (segment_index = 0 ) == rec_start_before
583+ assert recording .get_end_time (segment_index = 0 ) == rec_end_before
584+
585+ # The sorting's times should be shifted
586+ assert sorting .get_start_time (segment_index = 0 ) == rec_start_before + 5.0
587+ assert sorting .get_end_time (segment_index = 0 ) == rec_end_before + 5.0
588+ spike_times_after = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
589+ assert np .allclose (spike_times_after , spike_times_before + 5.0 )
590+
591+ def test_time_conversion_roundtrip_after_shift (self ):
592+ """sample_index_to_time and time_to_sample_index must remain inverses after a shift."""
593+ recording = generate_recording (num_channels = 4 , durations = [10 ])
594+ sorting = generate_sorting (num_units = 5 , durations = [10 ])
595+ sorting .register_recording (recording )
596+
597+ sorting .shift_times (shift = 5.0 )
598+
599+ # Frame 30000 is 1.0s in the recording. After a 5.0s shift, the sorting should report 6.0s.
600+ time = sorting .sample_index_to_time (30000 , segment_index = 0 )
601+ assert time == recording .sample_index_to_time (30000 , segment_index = 0 ) + 5.0
602+
603+ # The inverse: 6.0s in the sorting should map back to frame 30000.
604+ frame = sorting .time_to_sample_index (time , segment_index = 0 )
605+ assert frame == 30000
606+
607+ def test_shift_times_with_time_vector (self ):
608+ """Shift on sorting composes with a recording that has an explicit time vector,
609+ preserving the irregular spacing."""
610+ recording = generate_recording (num_channels = 4 , durations = [1.0 ])
611+ num_samples = recording .get_num_samples (segment_index = 0 )
612+ # Irregular timestamps starting at 100.0
613+ times = (
614+ 100.0
615+ + np .cumsum (np .random .RandomState (0 ).uniform (0.5 , 1.5 , num_samples )) / recording .get_sampling_frequency ()
616+ )
617+ recording .set_times (times , segment_index = 0 , with_warning = False )
618+
619+ sorting = generate_sorting (num_units = 5 , durations = [1.0 ])
620+ sorting .register_recording (recording )
621+ unit_id = sorting .unit_ids [0 ]
622+
623+ spike_times_before = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
624+
625+ sorting .shift_times (shift = 5.0 )
626+
627+ spike_times_after = sorting .get_unit_spike_train (unit_id , segment_index = 0 , return_times = True )
628+ # Irregular spacing preserved, everything shifted by 5.0
629+ assert np .allclose (spike_times_after , spike_times_before + 5.0 )
630+
631+ # Recording is untouched
632+ assert np .allclose (recording .get_times (segment_index = 0 ), times )
0 commit comments