Skip to content

Commit 92e9253

Browse files
committed
auto-set contact ids
1 parent 192dc86 commit 92e9253

4 files changed

Lines changed: 45 additions & 25 deletions

File tree

src/probeinterface/io.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -389,15 +389,6 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup
389389
json.dump({"ProbeId": probes_dict}, f, indent=4)
390390

391391
# Step 3: GENERATION OF CONTACTS.TSV
392-
# ensure required contact identifiers are present
393-
for probe in probes:
394-
if probe.contact_ids is None:
395-
raise ValueError(
396-
"Contacts must have unique contact ids "
397-
"and not None for export to BIDS probe format."
398-
"Use `probegroup.auto_generate_contact_ids`."
399-
)
400-
401392
df = probegroup.to_dataframe()
402393
index = range(sum([p.get_contact_count() for p in probes]))
403394
df.rename(columns=tsv_label_map_to_BIDS, inplace=True)

src/probeinterface/probe.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,10 @@ def set_contacts(
336336
Defines the two axes of the contact plane for each electrode.
337337
The third dimension corresponds to the probe `ndim` (2d or 3d).
338338
contact_ids: array[str] | None, default: None
339-
Defines the contact ids for the contacts. If None, contact ids are not assigned.
339+
Defines the contact ids for the contacts. If None, contact ids are
340+
auto-generated as the zero-indexed strings ``["0", "1", ..., str(n - 1)]``
341+
so a Probe always carries a stable, slice-invariant handle for each
342+
contact. Pass an explicit array to override.
340343
shank_ids : array[str] | None, default: None
341344
Defines the shank ids for the contacts. If None, then
342345
these are assigned to a unique Shank.
@@ -378,8 +381,9 @@ def set_contacts(
378381
plane_axes = np.array(plane_axes)
379382
self._contact_plane_axes = plane_axes
380383

381-
if contact_ids is not None:
382-
self.set_contact_ids(contact_ids)
384+
if contact_ids is None:
385+
contact_ids = np.arange(n).astype(str)
386+
self.set_contact_ids(contact_ids)
383387

384388
if shank_ids is None:
385389
# self._shank_ids = np.zeros(n, dtype=str)
@@ -566,8 +570,9 @@ def set_contact_ids(self, contact_ids: np.ndarray | list):
566570
"""
567571
contact_ids = np.asarray(contact_ids)
568572
if np.all([c == "" for c in contact_ids]):
569-
self._contact_ids = None
570-
return
573+
# Backward compat: previous versions serialized "unset" as empty
574+
# strings. A Probe now always carries contact_ids, so regenerate.
575+
contact_ids = np.arange(self.get_contact_count()).astype(str)
571576

572577
if contact_ids.size != self.get_contact_count():
573578
raise ValueError(
@@ -1085,10 +1090,7 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:
10851090
if self._contact_sides is not None:
10861091
arr["contact_sides"] = self.contact_sides
10871092

1088-
if self.contact_ids is None:
1089-
arr["contact_ids"] = [""] * self.get_contact_count()
1090-
else:
1091-
arr["contact_ids"] = self.contact_ids
1093+
arr["contact_ids"] = self.contact_ids
10921094

10931095
if complete:
10941096
arr["si_units"] = self.si_units

tests/test_probe.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,40 @@ def test_probe():
141141
# ~ plt.show()
142142

143143

144+
def test_set_contacts_auto_generates_contact_ids():
145+
"""When contact_ids is not supplied, Probe auto-generates ['0', ..., str(n-1)]."""
146+
probe = Probe(ndim=2, si_units="um")
147+
positions = np.array([[0, 0], [10, 0], [20, 0], [30, 0]])
148+
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
149+
150+
assert probe.contact_ids is not None
151+
np.testing.assert_array_equal(probe.contact_ids, np.array(["0", "1", "2", "3"]))
152+
153+
154+
def test_set_contacts_respects_explicit_contact_ids():
155+
"""An explicit contact_ids argument is preserved verbatim."""
156+
probe = Probe(ndim=2, si_units="um")
157+
positions = np.array([[0, 0], [10, 0], [20, 0]])
158+
probe.set_contacts(
159+
positions=positions,
160+
shapes="circle",
161+
shape_params={"radius": 5},
162+
contact_ids=["a", "b", "c"],
163+
)
164+
165+
np.testing.assert_array_equal(probe.contact_ids, np.array(["a", "b", "c"]))
166+
167+
168+
def test_set_contact_ids_all_empty_strings_regenerates():
169+
"""Backward compat: older serialized probes used empty strings for 'unset'."""
170+
probe = Probe(ndim=2, si_units="um")
171+
positions = np.array([[0, 0], [10, 0], [20, 0]])
172+
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
173+
probe.set_contact_ids(["", "", ""])
174+
175+
np.testing.assert_array_equal(probe.contact_ids, np.array(["0", "1", "2"]))
176+
177+
144178
def test_probe_equality_dunder():
145179
probe1 = generate_dummy_probe()
146180
probe2 = generate_dummy_probe()

tests/test_probegroup.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,6 @@ def test_copy_preserves_device_channel_indices(probegroup):
179179
)
180180

181181

182-
def test_copy_does_not_preserve_contact_ids(probegroup):
183-
"""Probe.copy() intentionally does not copy contact_ids."""
184-
pg_copy = probegroup.copy()
185-
# All contact_ids should be empty strings after copy
186-
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())
187-
188-
189182
def test_copy_is_independent(probegroup):
190183
"""Mutating the copy must not affect the original."""
191184
original_positions = probegroup.probes[0].contact_positions.copy()

0 commit comments

Comments
 (0)