Skip to content

Commit 3dbfacf

Browse files
Support time_vector in resample() (#4429)
1 parent c08e933 commit 3dbfacf

2 files changed

Lines changed: 123 additions & 13 deletions

File tree

src/spikeinterface/preprocessing/resample.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ def __init__(
6565
margin = int(margin_ms * recording.get_sampling_frequency() / 1000)
6666

6767
BasePreprocessor.__init__(self, recording, sampling_frequency=resample_rate, dtype=dtype)
68-
# in case there was a time_vector, it will be dropped for sanity.
6968
for parent_segment in recording._recording_segments:
70-
parent_segment.time_vector = None
7169
self.add_recording_segment(
7270
ResampleRecordingSegment(
7371
parent_segment,
@@ -96,24 +94,44 @@ def __init__(
9694
margin,
9795
dtype,
9896
):
99-
# Do not use BasePreprocessorSegment bcause we have to reset the sampling rate!
100-
BaseRecordingSegment.__init__(
101-
self,
102-
sampling_frequency=resample_rate,
103-
t_start=parent_recording_segment.t_start,
104-
)
97+
self._resample_rate = resample_rate
10598
self._parent_segment = parent_recording_segment
10699
self._parent_rate = parent_rate
107100
self._margin = margin
108101
self._dtype = dtype
109102

103+
# Compute time_vector or t_start, following the pattern from DecimateRecordingSegment.
104+
# Do not use BasePreprocessorSegment because we have to reset the sampling rate!
105+
if parent_recording_segment.time_vector is not None:
106+
parent_tv = np.asarray(parent_recording_segment.time_vector)
107+
n_out = int(len(parent_tv) / parent_rate * resample_rate)
108+
109+
if parent_rate % resample_rate == 0:
110+
q_int = int(parent_rate / resample_rate)
111+
time_vector = parent_tv[::q_int][:n_out]
112+
else:
113+
warnings.warn(
114+
"Resampling with a non-integer ratio requires interpolating the time_vector. "
115+
"An integer ratio (parent_rate / resample_rate) is more performant."
116+
)
117+
parent_indices = np.linspace(0, len(parent_tv) - 1, n_out)
118+
time_vector = np.interp(parent_indices, np.arange(len(parent_tv)), parent_tv)
119+
120+
BaseRecordingSegment.__init__(self, sampling_frequency=None, t_start=None, time_vector=time_vector)
121+
else:
122+
BaseRecordingSegment.__init__(
123+
self, sampling_frequency=resample_rate, t_start=parent_recording_segment.t_start
124+
)
125+
110126
def get_num_samples(self):
111-
return int(self._parent_segment.get_num_samples() / self._parent_rate * self.sampling_frequency)
127+
if self.time_vector is not None:
128+
return len(self.time_vector)
129+
return int(self._parent_segment.get_num_samples() / self._parent_rate * self._resample_rate)
112130

113131
def get_traces(self, start_frame, end_frame, channel_indices):
114132
# get parent traces with margin
115133
parent_start_frame, parent_end_frame = [
116-
int((frame / self.sampling_frequency) * self._parent_rate) for frame in [start_frame, end_frame]
134+
int((frame / self._resample_rate) * self._parent_rate) for frame in [start_frame, end_frame]
117135
]
118136
parent_traces, left_margin, right_margin = get_chunk_with_margin(
119137
self._parent_segment,
@@ -126,7 +144,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
126144
)
127145
# get left and right margins for the resampled case
128146
left_margin_rs, right_margin_rs = [
129-
int((margin / self._parent_rate) * self.sampling_frequency) for margin in [left_margin, right_margin]
147+
int((margin / self._parent_rate) * self._resample_rate) for margin in [left_margin, right_margin]
130148
]
131149

132150
# get the size for the resampled traces in case of resample:
@@ -136,9 +154,9 @@ def get_traces(self, start_frame, end_frame, channel_indices):
136154
# Check which method to use:
137155
from scipy import signal
138156

139-
if np.mod(self._parent_rate, self.sampling_frequency) == 0:
157+
if np.mod(self._parent_rate, self._resample_rate) == 0:
140158
# Ratio between sampling frequencies
141-
q = int(self._parent_rate / self.sampling_frequency)
159+
q = int(self._parent_rate / self._resample_rate)
142160
# Decimate can have issues for some cases, returning NaNs
143161
resampled_traces = signal.decimate(parent_traces, q=q, axis=0)
144162
# If that's the case, use signal.resample

src/spikeinterface/preprocessing/tests/test_resample.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,98 @@ def test_resample_by_chunks():
210210
plt.show()
211211

212212

213+
def test_resample_preserves_t_start():
214+
"""Resampling should preserve t_start when the parent has one."""
215+
sampling_frequency = 30000
216+
t_start = 100.5
217+
traces = np.random.randn(sampling_frequency * 2, 2).astype(np.float32)
218+
parent_rec = NumpyRecording(traces, sampling_frequency)
219+
parent_rec._recording_segments[0].t_start = t_start
220+
221+
resampled = resample(parent_rec, 500)
222+
assert resampled._recording_segments[0].t_start == t_start
223+
assert not resampled.has_time_vector()
224+
assert np.isclose(resampled.get_times()[0], t_start)
225+
226+
227+
def test_resample_does_not_mutate_parent():
228+
"""Resampling should not modify the parent recording's time_vector."""
229+
sampling_frequency = 30000
230+
n_samples = sampling_frequency * 2
231+
traces = np.random.randn(n_samples, 2).astype(np.float32)
232+
parent_rec = NumpyRecording(traces, sampling_frequency)
233+
time_vector = np.arange(n_samples, dtype="float64") / sampling_frequency + 50.0
234+
parent_rec.set_times(time_vector)
235+
236+
assert parent_rec.has_time_vector()
237+
resample(parent_rec, 500)
238+
assert parent_rec.has_time_vector(), "Parent time_vector was mutated by resample!"
239+
np.testing.assert_array_equal(parent_rec.get_times(), time_vector)
240+
241+
242+
def test_resample_preserves_time_vector_integer_ratio():
243+
"""Resampling with integer ratio should slice the parent time_vector."""
244+
sampling_frequency = 30000
245+
resample_rate = 500
246+
n_samples = sampling_frequency * 2
247+
traces = np.random.randn(n_samples, 2).astype(np.float32)
248+
parent_rec = NumpyRecording(traces, sampling_frequency)
249+
250+
# Create a time_vector with a gap (simulating artifact removal)
251+
time_vector = np.arange(n_samples, dtype="float64") / sampling_frequency
252+
# Insert a 5-second gap at the midpoint
253+
midpoint = n_samples // 2
254+
time_vector[midpoint:] += 5.0
255+
parent_rec.set_times(time_vector)
256+
257+
resampled = resample(parent_rec, resample_rate)
258+
259+
assert resampled.has_time_vector()
260+
resampled_times = resampled.get_times()
261+
n_out = resampled.get_num_samples()
262+
263+
# Output length should be consistent
264+
assert len(resampled_times) == n_out
265+
266+
# The gap should be preserved: check that the jump exists in the resampled times
267+
diffs = np.diff(resampled_times)
268+
normal_dt = 1.0 / resample_rate
269+
gap_indices = np.where(diffs > normal_dt * 2)[0]
270+
assert len(gap_indices) == 1, "The gap should appear exactly once in resampled times"
271+
assert np.isclose(diffs[gap_indices[0]], normal_dt + 5.0, atol=normal_dt)
272+
273+
# Start time should match
274+
assert np.isclose(resampled_times[0], time_vector[0])
275+
276+
277+
def test_resample_preserves_time_vector_non_integer_ratio():
278+
"""Resampling with non-integer ratio should interpolate the time_vector."""
279+
sampling_frequency = 30000
280+
resample_rate = 700 # 30000 / 700 is not integer
281+
n_samples = sampling_frequency * 2
282+
traces = np.random.randn(n_samples, 2).astype(np.float32)
283+
parent_rec = NumpyRecording(traces, sampling_frequency)
284+
285+
time_vector = np.arange(n_samples, dtype="float64") / sampling_frequency + 10.0
286+
parent_rec.set_times(time_vector)
287+
288+
import warnings as _warnings
289+
290+
with _warnings.catch_warnings(record=True) as w:
291+
_warnings.simplefilter("always")
292+
resampled = resample(parent_rec, resample_rate)
293+
assert any("non-integer ratio" in str(warning.message).lower() for warning in w)
294+
295+
assert resampled.has_time_vector()
296+
resampled_times = resampled.get_times()
297+
assert len(resampled_times) == resampled.get_num_samples()
298+
assert np.isclose(resampled_times[0], 10.0, atol=1.0 / sampling_frequency)
299+
300+
213301
if __name__ == "__main__":
214302
test_resample_freq_domain()
215303
test_resample_by_chunks()
304+
test_resample_preserves_t_start()
305+
test_resample_does_not_mutate_parent()
306+
test_resample_preserves_time_vector_integer_ratio()
307+
test_resample_preserves_time_vector_non_integer_ratio()

0 commit comments

Comments
 (0)