diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index d42906a4..26f23bff 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -7,12 +7,21 @@ class ProbeGroup: """ Class to handle a group of Probe objects and the global wiring to a device. - Optionally, it can handle the location of different probes. + Internally, this is represented as a list of Probe object. + + The ProbeGroup is the object saved in the json based probeinterface format, even if there only one probe. + + Tiny detail: when using `PropbeGroup.to_numpy()` / `PropbeGroup.to_dataframe()` by default the contact order + is the "natural" one (stacked order of each probe). But optionally, this order can be more complex, for instance + some contact of each probe are interleaved, in this case a optional reordering can be applied. + + """ def __init__(self): self.probes = [] + self._global_contact_order = None def add_probe(self, probe: Probe) -> None: """ @@ -114,6 +123,9 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: pg_arr.append(arr_ext) pg_arr = np.concatenate(pg_arr, axis=0) + + if self._global_contact_order is not None: + pg_arr = pg_arr[self._global_contact_order] return pg_arr @staticmethod @@ -121,6 +133,10 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": """Create ProbeGroup from a complex numpy array see ProbeGroup.to_numpy() + Note that if the contact_vector has several probe and some contact are interleaved, then the ProbeGroup will + have a non natural ordering (contact from probes are stack): in short ProbeGroup._global_contact_order + will be not None. + Parameters ---------- arr : np.array @@ -131,7 +147,12 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": probegroup : ProbeGroup The instantiated ProbeGroup object """ - from .probe import Probe + + # Check if contacts are interleaved + num_probes = np.unique(arr["probe_index"]).size + is_interleaved = (num_probes > 1) and np.any(np.diff(arr["probe_index"]) < 0) + if is_interleaved: + global_contact_order = [] probes_indices = np.unique(arr["probe_index"]) probegroup = ProbeGroup() @@ -139,6 +160,14 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": mask = arr["probe_index"] == probe_index probe = Probe.from_numpy(arr[mask]) probegroup.add_probe(probe) + + if is_interleaved: + global_contact_order.append(np.flatnonzero(mask)) + + if is_interleaved: + # the argsort is for the 'reverse' order! + probegroup._global_contact_order = np.argsort(np.concatenate(global_contact_order)) + return probegroup def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": @@ -181,6 +210,11 @@ def to_dict(self, array_as_list: bool = False) -> dict: for probe_ind, probe in enumerate(self.probes): probe_dict = probe.to_dict(array_as_list=array_as_list) d["probes"].append(probe_dict) + if self._global_contact_order is not None: + global_contact_order = self._global_contact_order + if array_as_list: + global_contact_order = global_contact_order.to_list() + d["global_contact_order"] = global_contact_order return d @staticmethod @@ -201,6 +235,11 @@ def from_dict(d: dict) -> "ProbeGroup": for probe_dict in d["probes"]: probe = Probe.from_dict(probe_dict) probegroup.add_probe(probe) + + global_contact_order = d.get("global_contact_order", None) + if global_contact_order is not None: + probegroup._global_contact_order = np.asarray(global_contact_order) + return probegroup def get_global_device_channel_indices(self) -> np.ndarray: @@ -226,19 +265,23 @@ def get_global_device_channel_indices(self) -> np.ndarray: channels["device_channel_indices"] = arr["device_channel_indices"] return channels - def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: + def set_global_device_channel_indices(self, device_channel_indices: np.ndarray | list) -> None: """ - Set global indices for all probes + Set global indices for all probes. + + Important note : if the order of contacts is not "natural" then the device_channel_indices + is applied is the real/reordered contacts vector. In short, the device_channel_indices is zipped to + ProbeGroup.to_numpy() (always ordered). Parameters ---------- channels: np.ndarray | list The device channal indices to be set """ - channels = np.asarray(channels) - if channels.size != self.get_contact_count(): + device_channel_indices = np.asarray(device_channel_indices) + if device_channel_indices.size != self.get_contact_count(): raise ValueError( - f"Wrong channels size {channels.size} for the number of channels {self.get_contact_count()}" + f"Wrong channels size {device_channel_indices.size} for the number of channels {self.get_contact_count()}" ) # first reset previous indices @@ -246,11 +289,16 @@ def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None n = probe.get_contact_count() probe.set_device_channel_indices([-1] * n) + if self._global_contact_order is not None: + # this is tricky conceptually but needed for consistency + rev_order = np.argsort(self._global_contact_order) + device_channel_indices = device_channel_indices[rev_order] + # then set new indices ind = 0 for i, probe in enumerate(self.probes): n = probe.get_contact_count() - probe.set_device_channel_indices(channels[ind : ind + n]) + probe.set_device_channel_indices(device_channel_indices[ind : ind + n]) ind += n def get_global_contact_ids(self) -> np.ndarray: @@ -275,6 +323,8 @@ def get_global_contact_positions(self) -> np.ndarray: An array of the contact positions across all probes """ contact_positions = np.vstack([probe.contact_positions for probe in self.probes]) + if self._global_contact_order is not None: + contact_positions = contact_positions[self._global_contact_order] return contact_positions def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": @@ -299,48 +349,28 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": n = self.get_contact_count() selection = np.asarray(selection) + if selection.dtype.kind not in ("b", "i"): + raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + if selection.dtype == "bool": assert selection.shape == ( n, ), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}" - (selection_indices,) = np.nonzero(selection) - elif selection.dtype.kind == "i": - assert np.unique(selection).size == selection.size - if len(selection) > 0: - assert ( - 0 <= np.min(selection) < n - ), f"An index within your selection is out of bounds {np.min(selection)}" - assert ( - 0 <= np.max(selection) < n - ), f"An index within your selection is out of bounds {np.max(selection)}" - selection_indices = selection - else: - selection_indices = [] - else: - raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}") + selection = np.flatnonzero(selection) - if len(selection_indices) == 0: - return ProbeGroup() + if len(selection) == 0: + raise ValueError("ProbeGroup.get_slice() with empty selection is not handled") + # return ProbeGroup() - # Map selection to indices of individual probes - ind = 0 - sliced_probes = [] - for probe in self.probes: - n = probe.get_contact_count() - probe_limits = (ind, ind + n) - ind += n + assert np.unique(selection).size == selection.size + assert 0 <= np.min(selection) < n, f"An index within your selection is out of bounds {np.min(selection)}" + assert 0 <= np.max(selection) < n, f"An index within your selection is out of bounds {np.max(selection)}" + + contact_arr = self.to_numpy(complete=True) + contact_arr = contact_arr[selection] + sliced_probe_group = ProbeGroup.from_numpy(contact_arr) - probe_selection_indices = selection_indices[ - (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - ] - if len(probe_selection_indices) == 0: - continue - sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - sliced_probes.append(sliced_probe) - - sliced_probe_group = ProbeGroup() - for probe in sliced_probes: - sliced_probe_group.add_probe(probe) + # TODO annoatation probe per probe!! return sliced_probe_group diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index ddd332d4..089c642a 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -6,8 +6,7 @@ import numpy as np -@pytest.fixture -def probegroup(): +def _make_probegroup(): """Fixture: a ProbeGroup with 3 probes, each with device channel indices set.""" probegroup = ProbeGroup() nchan = 0 @@ -21,6 +20,11 @@ def probegroup(): return probegroup +@pytest.fixture +def probegroup(): + return _make_probegroup() + + def test_probegroup(probegroup): indices = probegroup.get_global_device_channel_indices() @@ -200,7 +204,7 @@ def test_copy_is_independent(probegroup): np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions) -# ── get_slice() tests ─────────────────────────────────────────────────────── +# ── get_slice() simple : natural order def test_get_slice_by_bool(probegroup): @@ -232,10 +236,10 @@ def test_get_slice_preserves_positions(probegroup): np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected) -def test_get_slice_empty_selection(probegroup): - sliced = probegroup.get_slice(np.array([], dtype=int)) - assert sliced.get_contact_count() == 0 - assert len(sliced.probes) == 0 +# def test_get_slice_empty_selection(probegroup): +# sliced = probegroup.get_slice(np.array([], dtype=int)) +# assert sliced.get_contact_count() == 0 +# assert len(sliced.probes) == 0 def test_get_slice_wrong_bool_size(probegroup): @@ -260,6 +264,44 @@ def test_get_slice_all_contacts(probegroup): ) +# ── global_contact_order : to_numpy/from_numpy, to_dict/from_dict, get_slice + + +def test_reordred_probegroup(probegroup): + order = np.concatenate([np.arange(0, 96, 2), np.arange(95, 0, -2)]) + + contact_vector = probegroup.to_numpy(complete=True) + contact_vector = contact_vector[order] + + probegroup2 = ProbeGroup.from_numpy(contact_vector) + assert probegroup2._global_contact_order is not None + contact_vector2 = probegroup2.to_numpy(complete=True) + assert np.array_equal(contact_vector, contact_vector2) + + probegroup3 = ProbeGroup.from_dict(probegroup2.to_dict()) + assert probegroup3._global_contact_order is not None + contact_vector3 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector2, contact_vector3) + + probegroup4 = probegroup.get_slice(order) + assert probegroup4._global_contact_order is not None + contact_vector4 = probegroup4.to_numpy(complete=True) + assert np.array_equal(contact_vector3, contact_vector4) + + probegroup5 = ProbeGroup.from_dict(probegroup4.to_dict()) + assert probegroup5._global_contact_order is not None + contact_vector5 = probegroup3.to_numpy(complete=True) + assert np.array_equal(contact_vector4, contact_vector5) + + # let go back to original order + rev_order = np.argsort(order) + probegroup6 = probegroup5.get_slice(rev_order) + assert probegroup6._global_contact_order is None + + if __name__ == "__main__": - test_probegroup() - # ~ test_probegroup_3d() + probegroup = _make_probegroup() + + # test_probegroup(probegroup) + # test_probegroup_3d() + test_reordred_probegroup(probegroup)