Skip to content

Commit 75370d5

Browse files
authored
Implement copy/get_slice/get_global_contact_positions for probegroup (#416)
1 parent 2789811 commit 75370d5

File tree

2 files changed

+257
-24
lines changed

2 files changed

+257
-24
lines changed

src/probeinterface/probegroup.py

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class ProbeGroup:
1414
def __init__(self):
1515
self.probes = []
1616

17-
def add_probe(self, probe: Probe):
17+
def add_probe(self, probe: Probe) -> None:
1818
"""
1919
Add an additional probe to the ProbeGroup
2020
@@ -30,7 +30,7 @@ def add_probe(self, probe: Probe):
3030
self.probes.append(probe)
3131
probe._probe_group = self
3232

33-
def _check_compatible(self, probe: Probe):
33+
def _check_compatible(self, probe: Probe) -> None:
3434
if probe._probe_group is not None:
3535
raise ValueError(
3636
"This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup"
@@ -47,9 +47,25 @@ def _check_compatible(self, probe: Probe):
4747
self.probes = self.probes[:-1]
4848

4949
@property
50-
def ndim(self):
50+
def ndim(self) -> int:
5151
return self.probes[0].ndim
5252

53+
def copy(self) -> "ProbeGroup":
54+
"""
55+
Create a copy of the ProbeGroup
56+
57+
Returns
58+
-------
59+
copy: ProbeGroup
60+
A copy of the ProbeGroup
61+
"""
62+
copy = ProbeGroup()
63+
for probe in self.probes:
64+
copy.add_probe(probe.copy())
65+
global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"]
66+
copy.set_global_device_channel_indices(global_device_channel_indices)
67+
return copy
68+
5369
def get_contact_count(self) -> int:
5470
"""
5571
Total number of channels.
@@ -147,7 +163,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
147163
df.index = np.arange(df.shape[0], dtype="int64")
148164
return df
149165

150-
def to_dict(self, array_as_list: bool = False):
166+
def to_dict(self, array_as_list: bool = False) -> dict:
151167
"""Create a dictionary of all necessary attributes.
152168
153169
Parameters
@@ -168,7 +184,7 @@ def to_dict(self, array_as_list: bool = False):
168184
return d
169185

170186
@staticmethod
171-
def from_dict(d: dict):
187+
def from_dict(d: dict) -> "ProbeGroup":
172188
"""Instantiate a ProbeGroup from a dictionary
173189
174190
Parameters
@@ -210,7 +226,7 @@ def get_global_device_channel_indices(self) -> np.ndarray:
210226
channels["device_channel_indices"] = arr["device_channel_indices"]
211227
return channels
212228

213-
def set_global_device_channel_indices(self, channels: np.ndarray | list):
229+
def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None:
214230
"""
215231
Set global indices for all probes
216232
@@ -249,7 +265,86 @@ def get_global_contact_ids(self) -> np.ndarray:
249265
contact_ids = self.to_numpy(complete=True)["contact_ids"]
250266
return contact_ids
251267

252-
def check_global_device_wiring_and_ids(self):
268+
def get_global_contact_positions(self) -> np.ndarray:
269+
"""
270+
Gets all contact positions concatenated across probes
271+
272+
Returns
273+
-------
274+
contact_positions: np.ndarray
275+
An array of the contact positions across all probes
276+
"""
277+
contact_positions = np.vstack([probe.contact_positions for probe in self.probes])
278+
return contact_positions
279+
280+
def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
281+
"""
282+
Get a copy of the ProbeGroup with a sub selection of contacts.
283+
284+
Selection can be boolean or by index
285+
286+
Parameters
287+
----------
288+
selection : np.array of bool or int (for index)
289+
Either an np.array of bool or for desired selection of contacts
290+
or the indices of the desired contacts
291+
292+
Returns
293+
-------
294+
sliced_probe_group: ProbeGroup
295+
The sliced probe group
296+
297+
"""
298+
299+
n = self.get_contact_count()
300+
301+
selection = np.asarray(selection)
302+
if selection.dtype == "bool":
303+
assert selection.shape == (
304+
n,
305+
), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}"
306+
(selection_indices,) = np.nonzero(selection)
307+
elif selection.dtype.kind == "i":
308+
assert np.unique(selection).size == selection.size
309+
if len(selection) > 0:
310+
assert (
311+
0 <= np.min(selection) < n
312+
), f"An index within your selection is out of bounds {np.min(selection)}"
313+
assert (
314+
0 <= np.max(selection) < n
315+
), f"An index within your selection is out of bounds {np.max(selection)}"
316+
selection_indices = selection
317+
else:
318+
selection_indices = []
319+
else:
320+
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")
321+
322+
if len(selection_indices) == 0:
323+
return ProbeGroup()
324+
325+
# Map selection to indices of individual probes
326+
ind = 0
327+
sliced_probes = []
328+
for probe in self.probes:
329+
n = probe.get_contact_count()
330+
probe_limits = (ind, ind + n)
331+
ind += n
332+
333+
probe_selection_indices = selection_indices[
334+
(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])
335+
]
336+
if len(probe_selection_indices) == 0:
337+
continue
338+
sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0])
339+
sliced_probes.append(sliced_probe)
340+
341+
sliced_probe_group = ProbeGroup()
342+
for probe in sliced_probes:
343+
sliced_probe_group.add_probe(probe)
344+
345+
return sliced_probe_group
346+
347+
def check_global_device_wiring_and_ids(self) -> None:
253348
# check unique device_channel_indices for !=-1
254349
chans = self.get_global_device_channel_indices()
255350
keep = chans["device_channel_indices"] >= 0
@@ -258,7 +353,7 @@ def check_global_device_wiring_and_ids(self):
258353
if valid_chans.size != np.unique(valid_chans).size:
259354
raise ValueError("channel device indices are not unique across probes")
260355

261-
def auto_generate_probe_ids(self, *args, **kwargs):
356+
def auto_generate_probe_ids(self, *args, **kwargs) -> None:
262357
"""
263358
Annotate all probes with unique probe_id values.
264359
@@ -282,7 +377,7 @@ def auto_generate_probe_ids(self, *args, **kwargs):
282377
for pid, probe in enumerate(self.probes):
283378
probe.annotate(probe_id=probe_ids[pid])
284379

285-
def auto_generate_contact_ids(self, *args, **kwargs):
380+
def auto_generate_contact_ids(self, *args, **kwargs) -> None:
286381
"""
287382
Annotate all contacts with unique contact_id values.
288383

tests/test_probegroup.py

Lines changed: 153 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,27 @@
66
import numpy as np
77

88

9-
def test_probegroup():
9+
@pytest.fixture
10+
def probegroup():
11+
"""Fixture: a ProbeGroup with 3 probes, each with device channel indices set."""
1012
probegroup = ProbeGroup()
11-
1213
nchan = 0
1314
for i in range(3):
1415
probe = generate_dummy_probe()
1516
probe.move([i * 100, i * 80])
1617
n = probe.get_contact_count()
17-
probe.set_device_channel_indices(np.arange(n)[::-1] + nchan)
18-
shank_ids = np.ones(n)
19-
shank_ids[: n // 2] *= i * 2
20-
shank_ids[n // 2 :] *= i * 2 + 1
21-
probe.set_shank_ids(shank_ids)
18+
probe.set_device_channel_indices(np.arange(n) + nchan)
2219
probegroup.add_probe(probe)
23-
2420
nchan += n
21+
return probegroup
2522

23+
24+
def test_probegroup(probegroup):
2625
indices = probegroup.get_global_device_channel_indices()
2726

2827
ids = probegroup.get_global_contact_ids()
2928

3029
df = probegroup.to_dataframe()
31-
# ~ print(df['global_contact_ids'])
3230

3331
arr = probegroup.to_numpy(complete=False)
3432
other = ProbeGroup.from_numpy(arr)
@@ -38,12 +36,6 @@ def test_probegroup():
3836
d = probegroup.to_dict()
3937
other = ProbeGroup.from_dict(d)
4038

41-
# ~ from probeinterface.plotting import plot_probe_group, plot_probe
42-
# ~ import matplotlib.pyplot as plt
43-
# ~ plot_probe_group(probegroup)
44-
# ~ plot_probe_group(other)
45-
# ~ plt.show()
46-
4739
# checking automatic generation of ids with new dummy probes
4840
probegroup.probes = []
4941
for i in range(3):
@@ -116,6 +108,152 @@ def test_set_contact_ids_rejects_wrong_size():
116108
probe.set_contact_ids(["a", "b", "c"])
117109

118110

111+
# ── get_global_contact_positions() tests ────────────────────────────────────
112+
113+
114+
def test_get_global_contact_positions_shape(probegroup):
115+
pos = probegroup.get_global_contact_positions()
116+
assert pos.shape == (probegroup.get_contact_count(), probegroup.ndim)
117+
118+
119+
def test_get_global_contact_positions_matches_per_probe(probegroup):
120+
pos = probegroup.get_global_contact_positions()
121+
offset = 0
122+
for probe in probegroup.probes:
123+
n = probe.get_contact_count()
124+
np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions)
125+
offset += n
126+
127+
128+
def test_get_global_contact_positions_single_probe(probegroup):
129+
pos = probegroup.get_global_contact_positions()
130+
np.testing.assert_array_equal(
131+
pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions
132+
)
133+
134+
135+
def test_get_global_contact_positions_3d():
136+
pg = ProbeGroup()
137+
for i in range(2):
138+
probe = generate_dummy_probe().to_3d()
139+
probe.move([i * 100, i * 80, i * 30])
140+
pg.add_probe(probe)
141+
pos = pg.get_global_contact_positions()
142+
assert pos.shape[1] == 3
143+
assert pos.shape[0] == pg.get_contact_count()
144+
145+
146+
def test_get_global_contact_positions_reflects_move():
147+
"""Positions should reflect probe movement."""
148+
pg = ProbeGroup()
149+
probe = generate_dummy_probe()
150+
original_pos = probe.contact_positions.copy()
151+
probe.move([50, 60])
152+
pg.add_probe(probe)
153+
pos = pg.get_global_contact_positions()
154+
np.testing.assert_array_equal(pos, original_pos + np.array([50, 60]))
155+
156+
157+
# ── copy() tests ────────────────────────────────────────────────────────────
158+
159+
160+
def test_copy_returns_new_object(probegroup):
161+
pg_copy = probegroup.copy()
162+
assert pg_copy is not probegroup
163+
assert len(pg_copy.probes) == len(probegroup.probes)
164+
for orig, copied in zip(probegroup.probes, pg_copy.probes):
165+
assert orig is not copied
166+
167+
168+
def test_copy_preserves_positions(probegroup):
169+
pg_copy = probegroup.copy()
170+
for orig, copied in zip(probegroup.probes, pg_copy.probes):
171+
np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions)
172+
173+
174+
def test_copy_preserves_device_channel_indices(probegroup):
175+
pg_copy = probegroup.copy()
176+
np.testing.assert_array_equal(
177+
probegroup.get_global_device_channel_indices(),
178+
pg_copy.get_global_device_channel_indices(),
179+
)
180+
181+
182+
def test_copy_does_not_preserve_contact_ids(probegroup):
183+
"""Probe.copy() intentionally does not copy contact_ids."""
184+
pg_copy = probegroup.copy()
185+
# All contact_ids should be empty strings after copy
186+
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())
187+
188+
189+
def test_copy_is_independent(probegroup):
190+
"""Mutating the copy must not affect the original."""
191+
original_positions = probegroup.probes[0].contact_positions.copy()
192+
pg_copy = probegroup.copy()
193+
pg_copy.probes[0].move([999, 999])
194+
np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions)
195+
196+
197+
# ── get_slice() tests ───────────────────────────────────────────────────────
198+
199+
200+
def test_get_slice_by_bool(probegroup):
201+
total = probegroup.get_contact_count()
202+
sel = np.zeros(total, dtype=bool)
203+
sel[:5] = True # first 5 contacts from the first probe
204+
sliced = probegroup.get_slice(sel)
205+
assert sliced.get_contact_count() == 5
206+
207+
208+
def test_get_slice_by_index(probegroup):
209+
indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes
210+
sliced = probegroup.get_slice(indices)
211+
assert sliced.get_contact_count() == 5
212+
213+
214+
def test_get_slice_preserves_device_channel_indices(probegroup):
215+
indices = np.array([0, 1, 2])
216+
sliced = probegroup.get_slice(indices)
217+
orig_chans = probegroup.get_global_device_channel_indices()["device_channel_indices"][:3]
218+
sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"]
219+
np.testing.assert_array_equal(sliced_chans, orig_chans)
220+
221+
222+
def test_get_slice_preserves_positions(probegroup):
223+
indices = np.array([0, 1, 2])
224+
sliced = probegroup.get_slice(indices)
225+
expected = probegroup.get_global_contact_positions()[indices]
226+
np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected)
227+
228+
229+
def test_get_slice_empty_selection(probegroup):
230+
sliced = probegroup.get_slice(np.array([], dtype=int))
231+
assert sliced.get_contact_count() == 0
232+
assert len(sliced.probes) == 0
233+
234+
235+
def test_get_slice_wrong_bool_size(probegroup):
236+
with pytest.raises(AssertionError):
237+
probegroup.get_slice(np.array([True, False])) # wrong size
238+
239+
240+
def test_get_slice_out_of_bounds(probegroup):
241+
total = probegroup.get_contact_count()
242+
with pytest.raises(AssertionError):
243+
probegroup.get_slice(np.array([total + 10]))
244+
245+
246+
def test_get_slice_all_contacts(probegroup):
247+
"""Slicing with all contacts should give an equivalent ProbeGroup."""
248+
total = probegroup.get_contact_count()
249+
sliced = probegroup.get_slice(np.arange(total))
250+
assert sliced.get_contact_count() == total
251+
np.testing.assert_array_equal(
252+
sliced.get_global_contact_positions(),
253+
probegroup.get_global_contact_positions(),
254+
)
255+
256+
119257
if __name__ == "__main__":
120258
test_probegroup()
121259
# ~ test_probegroup_3d()

0 commit comments

Comments
 (0)