Skip to content

Commit ea53a47

Browse files
authored
Merge pull request #19 from SpikeInterface/remove-shank
Remove Shank abstraction
2 parents e54c0f9 + 55c46b5 commit ea53a47

File tree

5 files changed

+90
-161
lines changed

5 files changed

+90
-161
lines changed

spec/ndx-probeinterface.extensions.yaml

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,6 @@ groups:
4040
- 3
4141
doc: The planar polygon that outlines the probe contour.
4242
groups:
43-
- neurodata_type_inc: Shank
44-
doc: Neural probe shank object according to probeinterface specification
45-
quantity: '*'
46-
- neurodata_type_def: Shank
47-
neurodata_type_inc: NWBContainer
48-
doc: Neural probe shanks according to probeinterface specification
49-
attributes:
50-
- name: shank_id
51-
dtype: text
52-
doc: ID of the shank in the probe; must be a str
53-
groups:
5443
- neurodata_type_inc: ContactTable
5544
doc: Neural probe contacts according to probeinterface specification
5645
- neurodata_type_def: ContactTable
@@ -80,6 +69,11 @@ groups:
8069
dtype: text
8170
doc: unique ID of the contact
8271
quantity: '?'
72+
- name: shank_id
73+
neurodata_type_inc: VectorData
74+
dtype: text
75+
doc: shank ID of the contact
76+
quantity: '?'
8377
- name: contact_plane_axes
8478
neurodata_type_inc: VectorData
8579
dtype: float

src/pynwb/ndx_probeinterface/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
# TODO: import your classes here or define your class using get_class to make
1818
# them accessible at the package level
1919
Probe = get_class("Probe", "ndx-probeinterface")
20-
Shank = get_class("Shank", "ndx-probeinterface")
2120
ContactTable = get_class("ContactTable", "ndx-probeinterface")
2221

2322

src/pynwb/ndx_probeinterface/io.py

Lines changed: 42 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,6 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[D
3636
return devices
3737

3838

39-
def from_probegroup(probegroup: ProbeGroup):
40-
"""
41-
Construct ndx-probeinterface Probe devices from a probeinterface.ProbeGroup
42-
43-
Parameters
44-
----------
45-
probegroup: ProbeGroup
46-
ProbeGroup to convert to ndx-probeinterface Probe devices
47-
48-
Returns
49-
-------
50-
list
51-
List of ndx-probeinterface Probe devices
52-
"""
53-
assert isinstance(probegroup, ProbeGroup)
54-
devices = []
55-
for probe in probegroup.probes:
56-
devices.append(_single_probe_to_nwb_device(probe))
57-
return devices
58-
5939

6040
def to_probeinterface(ndx_probe) -> Probe:
6141
"""
@@ -84,29 +64,31 @@ def to_probeinterface(ndx_probe) -> Probe:
8464
device_channel_indices = None
8565

8666
possible_shape_keys = ["radius", "width", "height"]
87-
for shank in ndx_probe.shanks.values():
88-
positions.append(shank.contact_table["contact_position"][:])
89-
shapes.append(shank.contact_table["contact_shape"][:])
90-
if "contact_id" in shank.contact_table.colnames:
91-
if contact_ids is None:
92-
contact_ids = []
93-
contact_ids.append(shank.contact_table["contact_id"][:])
94-
if "device_channel_index_pi" in shank.contact_table.colnames:
95-
if device_channel_indices is None:
96-
device_channel_indices = []
97-
device_channel_indices.append(shank.contact_table["device_channel_index_pi"][:])
98-
if "contact_plane_axes" in shank.contact_table.colnames:
99-
if plane_axes is None:
100-
plane_axes = []
101-
plane_axes.append(shank.contact_table["contact_plane_axes"][:])
67+
contact_table = ndx_probe.contact_table
68+
69+
positions.append(contact_table["contact_position"][:])
70+
shapes.append(contact_table["contact_shape"][:])
71+
if "contact_id" in contact_table.colnames:
72+
if contact_ids is None:
73+
contact_ids = []
74+
contact_ids.append(contact_table["contact_id"][:])
75+
if "device_channel_index_pi" in contact_table.colnames:
76+
if device_channel_indices is None:
77+
device_channel_indices = []
78+
device_channel_indices.append(contact_table["device_channel_index_pi"][:])
79+
if "contact_plane_axes" in contact_table.colnames:
80+
if plane_axes is None:
81+
plane_axes = []
82+
plane_axes.append(contact_table["contact_plane_axes"][:])
83+
if "shank_id" in contact_table.colnames:
10284
if shank_ids is None:
10385
shank_ids = []
104-
shank_ids.append([str(shank.shank_id)] * len(shank.contact_table))
105-
for possible_shape_key in possible_shape_keys:
106-
if possible_shape_key in shank.contact_table.colnames:
107-
if shape_params is None:
108-
shape_params = []
109-
shape_params.append([{possible_shape_key: val} for val in shank.contact_table[possible_shape_key][:]])
86+
shank_ids.append(contact_table["shank_id"][:])
87+
for possible_shape_key in possible_shape_keys:
88+
if possible_shape_key in contact_table.colnames:
89+
if shape_params is None:
90+
shape_params = []
91+
shape_params.append([{possible_shape_key: val} for val in contact_table[possible_shape_key][:]])
11092

11193
positions = [item for sublist in positions for item in sublist]
11294
shapes = [item for sublist in shapes for item in sublist]
@@ -138,7 +120,6 @@ def _single_probe_to_nwb_device(probe: Probe):
138120
from pynwb import load_namespaces, get_class
139121

140122
Probe = get_class("Probe", "ndx-probeinterface")
141-
Shank = get_class("Shank", "ndx-probeinterface")
142123
ContactTable = get_class("ContactTable", "ndx-probeinterface")
143124

144125
contact_positions = probe.contact_positions
@@ -160,39 +141,25 @@ def _single_probe_to_nwb_device(probe: Probe):
160141
if k not in shape_keys:
161142
shape_keys.append(k)
162143

163-
shanks = []
164-
contact_tables = []
165-
for i_s, unique_shank in enumerate(unique_shanks):
166-
if shank_ids is not None:
167-
shank_indices = np.nonzero(shank_ids == unique_shank)[0]
168-
pi_shank = probe.get_shanks()[i_s]
169-
shank_name = f"Shank {pi_shank.shank_id}"
170-
shank_id = str(pi_shank.shank_id)
171-
else:
172-
shank_indices = np.arange(probe.get_contact_count())
173-
shank_name = "Shank 0"
174-
shank_id = "0"
175-
176-
contact_table = ContactTable(
177-
name="ContactTable",
178-
description="Contact Table for ProbeInterface",
179-
)
144+
contact_table = ContactTable(
145+
name="ContactTable",
146+
description="Contact Table for ProbeInterface",
147+
)
180148

181-
for index in shank_indices:
182-
kwargs = dict(
183-
contact_position=contact_positions[index],
184-
contact_plane_axes=contact_plane_axes[index],
185-
contact_id=contact_ids[index],
186-
contact_shape=contacts_arr["contact_shapes"][index],
187-
)
188-
for k in shape_keys:
189-
kwargs[k] = contacts_arr[k][index]
190-
if probe.device_channel_indices is not None:
191-
kwargs["device_channel_index_pi"] = probe.device_channel_indices[index]
192-
contact_table.add_row(kwargs)
193-
contact_tables.append(contact_table)
194-
shank = Shank(name=shank_name, shank_id=shank_id, contact_table=contact_table)
195-
shanks.append(shank)
149+
for index in np.arange(probe.get_contact_count()):
150+
kwargs = dict(
151+
contact_position=contact_positions[index],
152+
contact_plane_axes=contact_plane_axes[index],
153+
contact_id=contact_ids[index],
154+
contact_shape=contacts_arr["contact_shapes"][index],
155+
)
156+
for k in shape_keys:
157+
kwargs[k] = contacts_arr[k][index]
158+
if probe.device_channel_indices is not None:
159+
kwargs["device_channel_index_pi"] = probe.device_channel_indices[index]
160+
if probe.shank_ids is not None:
161+
kwargs["shank_id"] = probe.shank_ids[index]
162+
contact_table.add_row(kwargs)
196163

197164
if "serial_number" in probe.annotations:
198165
serial_number = probe.annotations["serial_number"]
@@ -209,13 +176,13 @@ def _single_probe_to_nwb_device(probe: Probe):
209176

210177
probe_device = Probe(
211178
name=probe.annotations["name"],
212-
shanks=shanks,
213179
model_name=model_name,
214180
serial_number=serial_number,
215181
manufacturer=manufacturer,
216182
ndim=probe.ndim,
217183
unit=unit_map[probe.si_units],
218184
planar_contour=planar_contour,
185+
contact_table=contact_table
219186
)
220187

221188
return probe_device

src/pynwb/tests/test_probe.py

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pynwb.file import ElectrodeTable as get_electrode_table
1212
from pynwb.testing import TestCase, remove_test_file, AcquisitionH5IOMixin
1313

14-
from ndx_probeinterface import Probe, Shank, ContactTable
14+
from ndx_probeinterface import Probe, ContactTable
1515

1616

1717
def set_up_nwbfile():
@@ -64,12 +64,8 @@ def test_constructor_from_probe_single_shank(self):
6464
self.assertIsInstance(device, Device)
6565
self.assertIsInstance(device, Probe)
6666

67-
# assert correct attributes
68-
self.assertEqual(len(device.shanks), 1)
69-
7067
# properties
71-
shank_names = list(device.shanks.keys())
72-
contact_table = device.shanks[shank_names[0]].contact_table
68+
contact_table = device.contact_table
7369
probe_array = probe.to_numpy()
7470
np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions)
7571
np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"])
@@ -79,44 +75,36 @@ def test_constructor_from_probe_single_shank(self):
7975
probe.set_device_channel_indices(device_channel_indices)
8076
devices_w_indices = Probe.from_probeinterface(probe)
8177
device_w_indices = devices_w_indices[0]
82-
shank_names = list(device_w_indices.shanks.keys())
83-
contact_table = device_w_indices.shanks[shank_names[0]].contact_table
78+
contact_table = device_w_indices.contact_table
8479
np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices)
8580

8681
def test_constructor_from_probe_multi_shank(self):
8782
"""Test that the constructor from Probe sets values as expected for multi-shank."""
8883

8984
probe = self.probe1
85+
probe_array = probe.to_numpy()
86+
87+
device_channel_indices = np.arange(probe.get_contact_count())
88+
probe.set_device_channel_indices(device_channel_indices)
9089
devices = Probe.from_probeinterface(probe)
9190
device = devices[0]
9291
# assert correct objects
9392
self.assertIsInstance(device, Device)
9493
self.assertIsInstance(device, Probe)
9594

96-
# assert correct attributes
97-
self.assertEqual(len(device.shanks), 2)
98-
99-
# properties
100-
shank_names = list(device.shanks.keys())
101-
probe_array = probe.to_numpy()
102-
103-
# set channel indices
104-
device_channel_indices = np.arange(probe.get_contact_count())
105-
probe.set_device_channel_indices(device_channel_indices)
106-
devices_w_indices = Probe.from_probeinterface(probe)
107-
device_w_indices = devices_w_indices[0]
108-
for i_s, shank_name in enumerate(shank_names):
109-
contact_table = device_w_indices.shanks[shank_name].contact_table
110-
pi_shank = probe.get_shanks()[i_s]
111-
np.testing.assert_array_equal(
112-
contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()]
113-
)
114-
np.testing.assert_array_equal(
115-
contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()]
116-
)
117-
np.testing.assert_array_equal(
118-
contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()]
119-
)
95+
contact_table = device.contact_table
96+
np.testing.assert_array_equal(
97+
contact_table["contact_position"][:], probe.contact_positions
98+
)
99+
np.testing.assert_array_equal(
100+
contact_table["contact_shape"][:], probe_array["contact_shapes"]
101+
)
102+
np.testing.assert_array_equal(
103+
contact_table["device_channel_index_pi"][:], device_channel_indices
104+
)
105+
np.testing.assert_array_equal(
106+
contact_table["shank_id"][:], probe.shank_ids
107+
)
120108

121109
def test_constructor_from_probegroup(self):
122110
"""Test that the constructor from probegroup sets values as expected."""
@@ -134,28 +122,22 @@ def test_constructor_from_probegroup(self):
134122
self.assertIsInstance(device, Device)
135123
self.assertIsInstance(device, Probe)
136124

137-
# assert correct attributes
138-
self.assertEqual(len(device.shanks), shank_counts[i])
139-
140125
# properties
141-
shank_names = list(device.shanks.keys())
142126
probe_array = probe.to_numpy()
143127
# TODO fix
144128
device_channel_indices = probe.device_channel_indices
145129
# set channel indices
146-
for i_s, shank_name in enumerate(shank_names):
147-
contact_table = device.shanks[shank_name].contact_table
148-
pi_shank = probe.get_shanks()[i_s]
149-
np.testing.assert_array_equal(
150-
contact_table["contact_position"][:], probe.contact_positions[pi_shank.get_indices()]
151-
)
152-
np.testing.assert_array_equal(
153-
contact_table["contact_shape"][:], probe_array["contact_shapes"][pi_shank.get_indices()]
154-
)
155-
156-
np.testing.assert_array_equal(
157-
contact_table["device_channel_index_pi"][:], device_channel_indices[pi_shank.get_indices()]
158-
)
130+
contact_table = device.contact_table
131+
np.testing.assert_array_equal(
132+
contact_table["contact_position"][:], probe.contact_positions
133+
)
134+
np.testing.assert_array_equal(
135+
contact_table["contact_shape"][:], probe_array["contact_shapes"]
136+
)
137+
138+
np.testing.assert_array_equal(
139+
contact_table["device_channel_index_pi"][:], device_channel_indices
140+
)
159141

160142

161143
class TestProbeRoundtrip(TestCase):

src/spec/create_extension_spec.py

Lines changed: 12 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def main():
2525
# TODO: define your new data types
2626
# see https://pynwb.readthedocs.io/en/latest/extensions.html#extending-nwb
2727
# for more information
28-
contact = NWBGroupSpec(
28+
contact_table = NWBGroupSpec(
2929
doc="Neural probe contacts according to probeinterface specification",
3030
datasets=[
3131
NWBDatasetSpec(
@@ -49,6 +49,13 @@ def main():
4949
neurodata_type_inc="VectorData",
5050
quantity="?",
5151
),
52+
NWBDatasetSpec(
53+
name="shank_id",
54+
doc="shank ID of the contact",
55+
dtype="text",
56+
neurodata_type_inc="VectorData",
57+
quantity="?",
58+
),
5259
NWBDatasetSpec(
5360
name="contact_plane_axes",
5461
doc="dimension of the probe",
@@ -90,26 +97,6 @@ def main():
9097
neurodata_type_inc="DynamicTable",
9198
neurodata_type_def="ContactTable",
9299
)
93-
shank = NWBGroupSpec(
94-
doc="Neural probe shanks according to probeinterface specification",
95-
attributes=[
96-
NWBAttributeSpec(
97-
name="shank_id",
98-
doc="ID of the shank in the probe; must be a str",
99-
dtype="text",
100-
required=True,
101-
),
102-
],
103-
groups=[
104-
NWBGroupSpec(
105-
doc="Neural probe contacts according to probeinterface specification",
106-
neurodata_type_inc="ContactTable",
107-
quantity=1,
108-
)
109-
],
110-
neurodata_type_inc="NWBContainer",
111-
neurodata_type_def="Shank",
112-
)
113100
probe = NWBGroupSpec(
114101
doc="Neural probe object according to probeinterface specification",
115102
attributes=[
@@ -138,9 +125,9 @@ def main():
138125
neurodata_type_def="Probe",
139126
groups=[
140127
NWBGroupSpec(
141-
doc="Neural probe shank object according to probeinterface specification",
142-
neurodata_type_inc="Shank",
143-
quantity="*",
128+
doc="Neural probe contacts according to probeinterface specification",
129+
neurodata_type_inc="ContactTable",
130+
quantity=1,
144131
)
145132
],
146133
datasets=[
@@ -154,7 +141,7 @@ def main():
154141
],
155142
)
156143

157-
new_data_types = [probe, shank, contact]
144+
new_data_types = [probe, contact_table]
158145

159146
# export the spec to yaml files in the spec folder
160147
output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "spec"))

0 commit comments

Comments
 (0)