Skip to content

Commit c13f166

Browse files
committed
Refactor probegroup to be array-based
1 parent 192dc86 commit c13f166

5 files changed

Lines changed: 192 additions & 163 deletions

File tree

src/probeinterface/io.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,7 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup
202202

203203
# create probe object and register with probegroup
204204
probe = Probe.from_dataframe(df=df_probe)
205-
probe.annotate(probe_id=probe_id)
206-
207205
probes[str(probe_id)] = probe
208-
probegroup.add_probe(probe)
209206

210207
ignore_annotations = [
211208
"probe_ids",
@@ -294,6 +291,10 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup
294291

295292
probe.annotate(**{contact_param: value_list})
296293

294+
# Step 5: add probes to probegroup
295+
for probe in probes.values():
296+
probegroup.add_probe(probe)
297+
297298
return probegroup
298299

299300

@@ -337,10 +338,6 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup
337338

338339
# Step 1: GENERATION OF PROBE.TSV
339340
# ensure required keys (probe_id, probe_type) are present
340-
341-
if any("probe_id" not in p.annotations for p in probes):
342-
probegroup.auto_generate_probe_ids()
343-
344341
for probe in probes:
345342
if "probe_id" not in probe.annotations:
346343
raise ValueError(

src/probeinterface/probe.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def __init__(
9999
# vertices for the shape of the probe
100100
self.probe_planar_contour = None
101101

102+
# the Probe can belong to a ProbeGroup
103+
self._probe_group = None
104+
102105
# This handles the shank id per contact
103106
# If None then one shank only
104107
self._shank_ids = None
@@ -129,9 +132,6 @@ def __init__(
129132
# same idea but handle in vector way for contacts
130133
self.contact_annotations = dict()
131134

132-
# the Probe can belong to a ProbeGroup
133-
self._probe_group = None
134-
135135
@property
136136
def contact_positions(self):
137137
"""The position of the center for each contact"""
@@ -260,6 +260,11 @@ def annotate(self, **kwargs):
260260
----------
261261
**kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1")
262262
"""
263+
if self._probe_group is not None:
264+
raise ValueError(
265+
"You cannot annotate a probe that belongs to a ProbeGroup. "
266+
"Annotate the probe before adding it to the ProbeGroup or use the `ProbeGroup.annotate_probe` method."
267+
)
263268
self.annotations.update(kwargs)
264269
self.check_annotations()
265270

@@ -271,6 +276,11 @@ def annotate_contacts(self, **kwargs):
271276
----------
272277
**kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...])
273278
"""
279+
if self._probe_group is not None:
280+
raise ValueError(
281+
"You cannot annotate contacts of a probe that belongs to a ProbeGroup. "
282+
"Annotate the probe before adding it to the ProbeGroup instead."
283+
)
274284
n = self.get_contact_count()
275285
for k, values in kwargs.items():
276286
assert len(values) == n, (
@@ -977,7 +987,7 @@ def from_dict(d: dict) -> "Probe":
977987

978988
return probe
979989

980-
def to_numpy(self, complete: bool = False) -> np.ndarray:
990+
def to_numpy(self, complete: bool = False, probe_index: int | None = None) -> np.ndarray:
981991
"""
982992
Export the probe to a numpy structured array.
983993
This array handles all contact attributes.
@@ -1035,7 +1045,10 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:
10351045
"""
10361046

10371047
# First define the dtype
1038-
dtype = [("x", "float64"), ("y", "float64")]
1048+
dtype = []
1049+
if probe_index is not None:
1050+
dtype = [("probe_index", "int64")]
1051+
dtype += [("x", "float64"), ("y", "float64")]
10391052
if self.ndim == 3:
10401053
dtype += [("z", "float64")]
10411054

@@ -1070,6 +1083,8 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:
10701083

10711084
# Then add the data to the structured array
10721085
arr = np.zeros(self.get_contact_count(), dtype=dtype)
1086+
if probe_index is not None:
1087+
arr["probe_index"] = probe_index
10731088
arr["x"] = self.contact_positions[:, 0]
10741089
arr["y"] = self.contact_positions[:, 1]
10751090
if self.ndim == 3:

0 commit comments

Comments
 (0)