Skip to content

Commit b5afdaf

Browse files
fix unit presence time_range (#4571)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b9a65fc commit b5afdaf

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

src/spikeinterface/widgets/unit_presence.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
7373
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
7474

7575
sorting = dp.sorting
76+
unit_ids = sorting.unit_ids
7677

7778
spikes = sorting.to_spike_vector(concatenated=False, use_cache=True)
7879
spikes = spikes[dp.segment_index]
@@ -85,18 +86,20 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
8586
ind1 = int(t1 * fs)
8687
mask = (spikes["sample_index"] >= ind0) & (spikes["sample_index"] <= ind1)
8788
spikes = spikes[mask]
89+
duration = t1 - t0
90+
else:
91+
last = spikes["sample_index"][-1]
92+
duration = last / fs
8893

8994
if spikes.size == 0:
9095
return
9196

92-
last = spikes["sample_index"][-1]
93-
max_time = last / fs
94-
95-
num_units = len(sorting.unit_ids)
96-
num_time_bins = int(max_time / dp.bin_duration_s) + 1
97+
num_units = len(unit_ids)
98+
num_time_bins = int(duration / dp.bin_duration_s) + 1
9799
map = np.zeros((num_units, num_time_bins))
98100
ind0 = spikes["unit_index"]
99-
ind1 = spikes["sample_index"] // int(dp.bin_duration_s * fs)
101+
offset = int(t0 * fs) if dp.time_range is not None else 0
102+
ind1 = (spikes["sample_index"] - offset) // int(dp.bin_duration_s * fs)
100103
map[ind0, ind1] += 1
101104

102105
if dp.smooth_sigma is not None:
@@ -109,8 +112,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
109112
smooth_kernel = smooth_kernel[np.newaxis, :]
110113
map = scipy.signal.oaconvolve(map, smooth_kernel, mode="same", axes=1)
111114

112-
im = self.ax.matshow(map, cmap="inferno", aspect="auto")
115+
extent = (dp.time_range[0], dp.time_range[1], len(unit_ids), 0) if dp.time_range is not None else None
116+
im = self.ax.imshow(
117+
map, cmap="inferno", aspect="auto", interpolation="nearest", extent=extent, vmax=1.0, vmin=0.0
118+
)
113119
self.ax.set_xlabel("Time (s)")
114120
self.ax.set_ylabel("Units")
121+
self.ax.set_yticks(np.arange(len(unit_ids)) + 0.5)
122+
self.ax.set_yticklabels(unit_ids)
115123

116124
self.figure.colorbar(im)

0 commit comments

Comments
 (0)