Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 73 additions & 43 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ class ProbeGroup:
"""
Class to handle a group of Probe objects and the global wiring to a device.

Optionally, it can handle the location of different probes.
Internally, this is represented as a list of Probe object.

The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe.

Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order
is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance
some contact of each probe are interleaved, in this case a optional reordering can be applied.



"""

def __init__(self):
self.probes = []
self._global_contact_order = None

def add_probe(self, probe: Probe) -> None:
"""
Expand Down Expand Up @@ -114,13 +123,20 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:
pg_arr.append(arr_ext)

pg_arr = np.concatenate(pg_arr, axis=0)

if self._global_contact_order is not None:
pg_arr = pg_arr[self._global_contact_order]
return pg_arr

@staticmethod
def from_numpy(arr: np.ndarray) -> "ProbeGroup":
"""Create ProbeGroup from a complex numpy array
see ProbeGroup.to_numpy()

Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will
have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order
will be not None.

Parameters
----------
arr : np.array
Expand All @@ -131,14 +147,27 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup":
probegroup : ProbeGroup
The instantiated ProbeGroup object
"""
from .probe import Probe

# Check if contacts are interleaved
num_probes = np.unique(arr["probe_index"]).size
is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0)
if is_interleaved:
global_contact_order = []

probes_indices = np.unique(arr["probe_index"])
probegroup = ProbeGroup()
for probe_index in probes_indices:
mask = arr["probe_index"] == probe_index
probe = Probe.from_numpy(arr[mask])
probegroup.add_probe(probe)

if is_interleaved:
global_contact_order.append(np.flatnonzero(mask))

if is_interleaved:
# the argsort is for the 'reverse' order!
probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order))

return probegroup

def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
Expand Down Expand Up @@ -181,6 +210,11 @@ def to_dict(self, array_as_list: bool = False) -> dict:
for probe_ind, probe in enumerate(self.probes):
probe_dict = probe.to_dict(array_as_list=array_as_list)
d["probes"].append(probe_dict)
if self._global_contact_order is not None:
global_contact_order = self._global_contact_order
if array_as_list:
global_contact_order = global_contact_order.to_list()
d["global_contact_order"] = global_contact_order
return d

@staticmethod
Expand All @@ -201,6 +235,11 @@ def from_dict(d: dict) -> "ProbeGroup":
for probe_dict in d["probes"]:
probe = Probe.from_dict(probe_dict)
probegroup.add_probe(probe)

global_contact_order = d.get("global_contact_order", None)
if global_contact_order is not None:
probegroup._global_contact_order = np.asarray(global_contact_order)

return probegroup

def get_global_device_channel_indices(self) -> np.ndarray:
Expand All @@ -226,31 +265,40 @@ def get_global_device_channel_indices(self) -> np.ndarray:
channels["device_channel_indices"] = arr["device_channel_indices"]
return channels

def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None:
def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None:
"""
Set global indices for all probes
Set global indices for all probes.

Important note : if the order of contacts is not "natural" then the device_channel_indices
is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to
ProbeGroup.to_numpy() (always ordered).

Parameters
----------
channels: np.ndarray | list
The device channal indices to be set
"""
channels = np.asarray(channels)
if channels.size != self.get_contact_count():
device_channel_indices = np.asarray(device_channel_indices)
if device_channel_indices.size != self.get_contact_count():
raise ValueError(
f"Wrong channels size {channels.size} for the number of channels {self.get_contact_count()}"
f"Wrong channels size {device_channel_indices.size} for the number of channels {self.get_contact_count()}"
)

# first reset previous indices
for i, probe in enumerate(self.probes):
n = probe.get_contact_count()
probe.set_device_channel_indices([-1] * n)

if self._global_contact_order is not None:
# this is tricky conceptually but needed for consistency
rev_order = np.argsort(self._global_contact_order)
device_channel_indices = device_channel_indices[rev_order]

# then set new indices
ind = 0
for i, probe in enumerate(self.probes):
n = probe.get_contact_count()
probe.set_device_channel_indices(channels[ind : ind + n])
probe.set_device_channel_indices(device_channel_indices[ind : ind + n])
ind += n

def get_global_contact_ids(self) -> np.ndarray:
Expand All @@ -275,6 +323,8 @@ def get_global_contact_positions(self) -> np.ndarray:
An array of the contact positions across all probes
"""
contact_positions = np.vstack([probe.contact_positions for probe in self.probes])
if self._global_contact_order is not None:
contact_positions = contact_positions[self._global_contact_order]
return contact_positions

def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
Expand All @@ -299,48 +349,28 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
n = self.get_contact_count()

selection = np.asarray(selection)
if selection.dtype.kind not in ("b", "i"):
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")

if selection.dtype == "bool":
assert selection.shape == (
n,
), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}"
(selection_indices,) = np.nonzero(selection)
elif selection.dtype.kind == "i":
assert np.unique(selection).size == selection.size
if len(selection) > 0:
assert (
0 <= np.min(selection) < n
), f"An index within your selection is out of bounds {np.min(selection)}"
assert (
0 <= np.max(selection) < n
), f"An index within your selection is out of bounds {np.max(selection)}"
selection_indices = selection
else:
selection_indices = []
else:
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")
selection = np.flatnonzero(selection)

if len(selection_indices) == 0:
return ProbeGroup()
if len(selection) == 0:
raise ValueError("ProbeGroup.get_slice() with empty selection is not handled")
# return ProbeGroup()

# Map selection to indices of individual probes
ind = 0
sliced_probes = []
for probe in self.probes:
n = probe.get_contact_count()
probe_limits = (ind, ind + n)
ind += n
assert np.unique(selection).size == selection.size
assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}"
assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}"

contact_arr = self.to_numpy(complete=True)
contact_arr = contact_arr[selection]
sliced_probe_group = ProbeGroup.from_numpy(contact_arr)

probe_selection_indices = selection_indices[
(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])
]
if len(probe_selection_indices) == 0:
continue
sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0])
sliced_probes.append(sliced_probe)

sliced_probe_group = ProbeGroup()
for probe in sliced_probes:
sliced_probe_group.add_probe(probe)
# TODO annoatation probe per probe!!

return sliced_probe_group

Expand Down
60 changes: 51 additions & 9 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import numpy as np


@pytest.fixture
def probegroup():
def _make_probegroup():
"""Fixture: a ProbeGroup with 3 probes, each with device channel indices set."""
probegroup = ProbeGroup()
nchan = 0
Expand All @@ -21,6 +20,11 @@ def probegroup():
return probegroup


@pytest.fixture
def probegroup():
return _make_probegroup()


def test_probegroup(probegroup):
indices = probegroup.get_global_device_channel_indices()

Expand Down Expand Up @@ -200,7 +204,7 @@ def test_copy_is_independent(probegroup):
np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions)


# ── get_slice() tests ───────────────────────────────────────────────────────
# ── get_slice() simple : natural order


def test_get_slice_by_bool(probegroup):
Expand Down Expand Up @@ -232,10 +236,10 @@ def test_get_slice_preserves_positions(probegroup):
np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected)


def test_get_slice_empty_selection(probegroup):
sliced = probegroup.get_slice(np.array([], dtype=int))
assert sliced.get_contact_count() == 0
assert len(sliced.probes) == 0
# def test_get_slice_empty_selection(probegroup):
# sliced = probegroup.get_slice(np.array([], dtype=int))
# assert sliced.get_contact_count() == 0
# assert len(sliced.probes) == 0


def test_get_slice_wrong_bool_size(probegroup):
Expand All @@ -260,6 +264,44 @@ def test_get_slice_all_contacts(probegroup):
)


# ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice


def test_reordred_probegroup(probegroup):
order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)])

contact_vector = probegroup.to_numpy(complete=True)
contact_vector = contact_vector[order]

probegroup2 = ProbeGroup.from_numpy(contact_vector)
assert probegroup2._global_contact_order is not None
contact_vector2 = probegroup2.to_numpy(complete=True)
assert np.array_equal(contact_vector, contact_vector2)

probegroup3 = ProbeGroup.from_dict(probegroup2.to_dict())
assert probegroup3._global_contact_order is not None
contact_vector3 = probegroup3.to_numpy(complete=True)
assert np.array_equal(contact_vector2, contact_vector3)

probegroup4 = probegroup.get_slice(order)
assert probegroup4._global_contact_order is not None
contact_vector4 = probegroup4.to_numpy(complete=True)
assert np.array_equal(contact_vector3, contact_vector4)

probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict())
assert probegroup5._global_contact_order is not None
contact_vector5 = probegroup3.to_numpy(complete=True)
assert np.array_equal(contact_vector4, contact_vector5)

# let go back to original order
rev_order = np.argsort(order)
probegroup6 = probegroup5.get_slice(rev_order)
assert probegroup6._global_contact_order is None


if __name__ == "__main__":
test_probegroup()
# ~ test_probegroup_3d()
probegroup = _make_probegroup()

# test_probegroup(probegroup)
# test_probegroup_3d()
test_reordred_probegroup(probegroup)
Loading