Skip to content

Commit 1518fe6

Browse files
committed
Implement copy/get_slice/get_global_contact_positions for probegroup
1 parent 2789811 commit 1518fe6

2 files changed

Lines changed: 279 additions & 0 deletions

File tree

src/probeinterface/probegroup.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ def _check_compatible(self, probe: Probe):
5050
def ndim(self):
5151
return self.probes[0].ndim
5252

53+
def copy(self) -> "ProbeGroup":
54+
"""
55+
Create a copy of the ProbeGroup
56+
57+
Returns
58+
-------
59+
copy: ProbeGroup
60+
A copy of the ProbeGroup
61+
"""
62+
copy = ProbeGroup()
63+
for probe in self.probes:
64+
copy.add_probe(probe.copy())
65+
global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"]
66+
copy.set_global_device_channel_indices(global_device_channel_indices)
67+
return copy
68+
5369
def get_contact_count(self) -> int:
5470
"""
5571
Total number of channels.
@@ -249,6 +265,79 @@ def get_global_contact_ids(self) -> np.ndarray:
249265
contact_ids = self.to_numpy(complete=True)["contact_ids"]
250266
return contact_ids
251267

268+
def get_global_contact_positions(self) -> np.ndarray:
269+
"""
270+
Gets all contact positions concatenated across probes
271+
272+
Returns
273+
-------
274+
contact_positions: np.ndarray
275+
An array of the contact positions across all probes
276+
"""
277+
contact_positions = np.vstack([probe.contact_positions for probe in self.probes])
278+
return contact_positions
279+
280+
def get_slice(self, selection: np.ndarray[bool | int]):
281+
"""
282+
Get a copy of the ProbeGroup with a sub selection of contacts.
283+
284+
Selection can be boolean or by index
285+
286+
Parameters
287+
----------
288+
selection : np.array of bool or int (for index)
289+
Either an np.array of bool or for desired selection of contacts
290+
or the indices of the desired contacts
291+
292+
Returns
293+
-------
294+
sliced_probe_group: ProbeGroup
295+
The sliced probe group
296+
297+
"""
298+
299+
n = self.get_contact_count()
300+
301+
selection = np.asarray(selection)
302+
if selection.dtype == "bool":
303+
assert selection.shape == (n,), (
304+
f"if array of bool given it must be the same size " "as the number of contacts {selection.shape} != {n}"
305+
)
306+
selection_indices, = np.nonzero(selection)
307+
elif selection.dtype.kind == "i":
308+
assert np.unique(selection).size == selection.size
309+
if len(selection) > 0:
310+
assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}"
311+
assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}"
312+
selection_indices = selection
313+
else:
314+
selection_indices = []
315+
else:
316+
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")
317+
318+
if len(selection_indices) == 0:
319+
return ProbeGroup()
320+
321+
# Map selection to indices of individual probes
322+
d = self.to_dict(array_as_list=False)
323+
ind = 0
324+
sliced_probes = []
325+
for probe in self.probes:
326+
n = probe.get_contact_count()
327+
probe_limits = (ind, ind + n)
328+
ind += n
329+
330+
probe_selection_indices = selection_indices[(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])]
331+
if len(probe_selection_indices) == 0:
332+
continue
333+
sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0])
334+
sliced_probes.append(sliced_probe)
335+
336+
sliced_probe_group = ProbeGroup()
337+
sliced_probe_group.probes = sliced_probes
338+
339+
return sliced_probe_group
340+
252341
def check_global_device_wiring_and_ids(self):
253342
# check unique device_channel_indices for !=-1
254343
chans = self.get_global_device_channel_indices()

tests/test_probegroup.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,196 @@ def test_set_contact_ids_rejects_wrong_size():
116116
probe.set_contact_ids(["a", "b", "c"])
117117

118118

119+
def _make_probegroup(n_probes=3):
120+
"""Helper: build a ProbeGroup with device channel indices set."""
121+
probegroup = ProbeGroup()
122+
nchan = 0
123+
for i in range(n_probes):
124+
probe = generate_dummy_probe()
125+
probe.move([i * 100, i * 80])
126+
n = probe.get_contact_count()
127+
probe.set_device_channel_indices(np.arange(n) + nchan)
128+
probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)])
129+
nchan += n
130+
probegroup.add_probe(probe)
131+
return probegroup
132+
133+
134+
def _make_probegroup_full(n_probes=3):
135+
"""Helper: build a ProbeGroup where **every** probe is added."""
136+
probegroup = ProbeGroup()
137+
nchan = 0
138+
for i in range(n_probes):
139+
probe = generate_dummy_probe()
140+
probe.move([i * 100, i * 80])
141+
n = probe.get_contact_count()
142+
probe.set_device_channel_indices(np.arange(n) + nchan)
143+
probe.set_contact_ids([f"e{j}" for j in range(nchan, nchan + n)])
144+
probegroup.add_probe(probe)
145+
nchan += n
146+
return probegroup
147+
148+
149+
# ── copy() tests ────────────────────────────────────────────────────────────
150+
151+
152+
def test_copy_returns_new_object():
153+
pg = _make_probegroup_full(2)
154+
pg_copy = pg.copy()
155+
assert pg_copy is not pg
156+
assert len(pg_copy.probes) == len(pg.probes)
157+
for orig, copied in zip(pg.probes, pg_copy.probes):
158+
assert orig is not copied
159+
160+
161+
def test_copy_preserves_positions():
162+
pg = _make_probegroup_full(2)
163+
pg_copy = pg.copy()
164+
for orig, copied in zip(pg.probes, pg_copy.probes):
165+
np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions)
166+
167+
168+
def test_copy_preserves_device_channel_indices():
169+
pg = _make_probegroup_full(2)
170+
pg_copy = pg.copy()
171+
np.testing.assert_array_equal(
172+
pg.get_global_device_channel_indices(),
173+
pg_copy.get_global_device_channel_indices(),
174+
)
175+
176+
177+
def test_copy_does_not_preserve_contact_ids():
178+
"""Probe.copy() intentionally does not copy contact_ids."""
179+
pg = _make_probegroup_full(2)
180+
pg_copy = pg.copy()
181+
# All contact_ids should be empty strings after copy
182+
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())
183+
184+
185+
def test_copy_is_independent():
186+
"""Mutating the copy must not affect the original."""
187+
pg = _make_probegroup_full(2)
188+
original_positions = pg.probes[0].contact_positions.copy()
189+
pg_copy = pg.copy()
190+
pg_copy.probes[0].move([999, 999])
191+
np.testing.assert_array_equal(pg.probes[0].contact_positions, original_positions)
192+
193+
194+
# ── get_slice() tests ───────────────────────────────────────────────────────
195+
196+
197+
def test_get_slice_by_bool():
198+
pg = _make_probegroup_full(2)
199+
total = pg.get_contact_count()
200+
sel = np.zeros(total, dtype=bool)
201+
sel[:5] = True # first 5 contacts from the first probe
202+
sliced = pg.get_slice(sel)
203+
assert sliced.get_contact_count() == 5
204+
205+
206+
def test_get_slice_by_index():
207+
pg = _make_probegroup_full(2)
208+
indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes
209+
sliced = pg.get_slice(indices)
210+
assert sliced.get_contact_count() == 5
211+
212+
213+
def test_get_slice_preserves_device_channel_indices():
214+
pg = _make_probegroup_full(2)
215+
indices = np.array([0, 1, 2])
216+
sliced = pg.get_slice(indices)
217+
orig_chans = pg.get_global_device_channel_indices()["device_channel_indices"][:3]
218+
sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"]
219+
np.testing.assert_array_equal(sliced_chans, orig_chans)
220+
221+
222+
def test_get_slice_preserves_positions():
223+
pg = _make_probegroup_full(2)
224+
indices = np.array([0, 1, 2])
225+
sliced = pg.get_slice(indices)
226+
expected = pg.get_global_contact_positions()[indices]
227+
np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected)
228+
229+
230+
def test_get_slice_empty_selection():
231+
pg = _make_probegroup_full(2)
232+
sliced = pg.get_slice(np.array([], dtype=int))
233+
assert sliced.get_contact_count() == 0
234+
assert len(sliced.probes) == 0
235+
236+
237+
def test_get_slice_wrong_bool_size():
238+
pg = _make_probegroup_full(2)
239+
with pytest.raises(AssertionError):
240+
pg.get_slice(np.array([True, False])) # wrong size
241+
242+
243+
def test_get_slice_out_of_bounds():
244+
pg = _make_probegroup_full(2)
245+
total = pg.get_contact_count()
246+
with pytest.raises(AssertionError):
247+
pg.get_slice(np.array([total + 10]))
248+
249+
250+
def test_get_slice_all_contacts():
251+
"""Slicing with all contacts should give an equivalent ProbeGroup."""
252+
pg = _make_probegroup_full(2)
253+
total = pg.get_contact_count()
254+
sliced = pg.get_slice(np.arange(total))
255+
assert sliced.get_contact_count() == total
256+
np.testing.assert_array_equal(
257+
sliced.get_global_contact_positions(),
258+
pg.get_global_contact_positions(),
259+
)
260+
261+
262+
# ── get_global_contact_positions() tests ────────────────────────────────────
263+
264+
265+
def test_get_global_contact_positions_shape():
266+
pg = _make_probegroup_full(3)
267+
pos = pg.get_global_contact_positions()
268+
assert pos.shape == (pg.get_contact_count(), pg.ndim)
269+
270+
271+
def test_get_global_contact_positions_matches_per_probe():
272+
pg = _make_probegroup_full(3)
273+
pos = pg.get_global_contact_positions()
274+
offset = 0
275+
for probe in pg.probes:
276+
n = probe.get_contact_count()
277+
np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions)
278+
offset += n
279+
280+
281+
def test_get_global_contact_positions_single_probe():
282+
pg = _make_probegroup_full(1)
283+
pos = pg.get_global_contact_positions()
284+
np.testing.assert_array_equal(pos, pg.probes[0].contact_positions)
285+
286+
287+
def test_get_global_contact_positions_3d():
288+
pg = ProbeGroup()
289+
for i in range(2):
290+
probe = generate_dummy_probe().to_3d()
291+
probe.move([i * 100, i * 80, i * 30])
292+
pg.add_probe(probe)
293+
pos = pg.get_global_contact_positions()
294+
assert pos.shape[1] == 3
295+
assert pos.shape[0] == pg.get_contact_count()
296+
297+
298+
def test_get_global_contact_positions_reflects_move():
299+
"""Positions should reflect probe movement."""
300+
pg = ProbeGroup()
301+
probe = generate_dummy_probe()
302+
original_pos = probe.contact_positions.copy()
303+
probe.move([50, 60])
304+
pg.add_probe(probe)
305+
pos = pg.get_global_contact_positions()
306+
np.testing.assert_array_equal(pos, original_pos + np.array([50, 60]))
307+
308+
119309
if __name__ == "__main__":
120310
test_probegroup()
121311
# ~ test_probegroup_3d()

0 commit comments

Comments
 (0)