Skip to content

Commit 2a1b9fe

Browse files
committed
Add roundtrip tests
1 parent caa20f1 commit 2a1b9fe

2 files changed

Lines changed: 60 additions & 26 deletions

File tree

src/pynwb/ndx_probeinterface/io.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,38 +69,60 @@ def to_probeinterface(ndx_probeinterface_probe) -> Probe:
6969
polygon = ndx_probeinterface_probe.planar_contour
7070

7171
positions = []
72-
contact_ids = []
7372
shapes = []
74-
shape_params = []
75-
shank_ids = []
76-
plane_axes = []
77-
channel_indices = []
73+
74+
contact_ids = None
75+
shape_params = None
76+
shank_ids = None
77+
plane_axes = None
78+
device_channel_indices = None
79+
80+
possible_shape_keys = ["radius", "width", "height"]
7881
for shank in ndx_probeinterface_probe.shanks.values():
7982
positions.append(shank.contact_table["contact_position"][:])
80-
contact_ids.append(shank.contact_table["contact_id"][:])
8183
shapes.append(shank.contact_table["contact_shape"][:])
82-
channel_indices.append(shank.contact_table["device_channel_index_pi"][:])
83-
plane_axes.append(shank.contact_table["contact_plane_axes"][:])
84-
shank_ids.append([int(shank.shank_id)] * len(shank.contact_table))
85-
# WARNING: currently assumes that all the contacts have the same shape
86-
shape_word = [shape for shape in shape_words if shape in shank.contact_table[:].columns][0]
87-
shape_params.append([{shape_word: val} for val in shank.contact_table[shape_word][:]])
84+
if "contact_id" in shank.contact_table.colnames:
85+
if contact_ids is None:
86+
contact_ids = []
87+
contact_ids.append(shank.contact_table["contact_id"][:])
88+
if "device_channel_index_pi" in shank.contact_table.colnames:
89+
if device_channel_indices is None:
90+
device_channel_indices = []
91+
device_channel_indices.append(shank.contact_table["device_channel_index_pi"][:])
92+
if "contact_plane_axes" in shank.contact_table.colnames:
93+
if plane_axes is None:
94+
plane_axes = []
95+
plane_axes.append(shank.contact_table["contact_plane_axes"][:])
96+
if shank_ids is None:
97+
shank_ids = []
98+
shank_ids.append([str(shank.shank_id)] * len(shank.contact_table))
99+
for possible_shape_key in possible_shape_keys:
100+
if possible_shape_key in shank.contact_table.colnames:
101+
if shape_params is None:
102+
shape_params = []
103+
shape_params.append([{possible_shape_key: val} for val in shank.contact_table[possible_shape_key][:]])
88104

89105
positions = [item for sublist in positions for item in sublist]
90-
contact_ids = [item for sublist in contact_ids for item in sublist]
91106
shapes = [item for sublist in shapes for item in sublist]
92-
plane_axes = [item for sublist in plane_axes for item in sublist]
93-
shank_ids = [item for sublist in shank_ids for item in sublist]
94-
channel_indices = [item for sublist in channel_indices for item in sublist]
95-
shape_params = [item for sublist in shape_params for item in sublist]
107+
108+
if contact_ids is not None:
109+
contact_ids = [item for sublist in contact_ids for item in sublist]
110+
if plane_axes is not None:
111+
plane_axes = [item for sublist in plane_axes for item in sublist]
112+
if shape_params is not None:
113+
shape_params = [item for sublist in shape_params for item in sublist]
114+
if shank_ids is not None:
115+
shank_ids = [item for sublist in shank_ids for item in sublist]
116+
if device_channel_indices is not None:
117+
device_channel_indices = [item for sublist in channel_indices for item in sublist]
96118

97119
probeinterface_probe = Probe(ndim=ndim, si_units=unit)
98120
probeinterface_probe.set_contacts(
99121
positions=positions, shapes=shapes, shape_params=shape_params, plane_axes=plane_axes, shank_ids=shank_ids
100122
)
101123
probeinterface_probe.set_contact_ids(contact_ids=contact_ids)
102-
probeinterface_probe.set_device_channel_indices(channel_indices=channel_indices)
103-
124+
if device_channel_indices is not None:
125+
probeinterface_probe.set_device_channel_indices(channel_indices=device_channel_indices)
104126
probeinterface_probe.set_planar_contour(polygon)
105127

106128
return probeinterface_probe

src/pynwb/tests/test_probe.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def tearDown(self):
174174
for path in [self.path0, self.path1, self.path2]:
175175
remove_test_file(path)
176176

177-
def test_roundtrip_from_probe_single_shank(self):
177+
def test_roundtrip_nwb_from_probe_single_shank(self):
178178
device = Probe.from_probe(self.probe0)
179179
self.nwbfile0.add_device(device)
180180

@@ -184,9 +184,9 @@ def test_roundtrip_from_probe_single_shank(self):
184184
with NWBHDF5IO(self.path0, mode="r", load_namespaces=True) as io:
185185
read_nwbfile = io.read()
186186
devices = read_nwbfile.devices
187-
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
187+
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
188188

189-
def test_roundtrip_from_probe_multi_shank(self):
189+
def test_roundtrip_nwb_from_probe_multi_shank(self):
190190
device = Probe.from_probe(self.probe1)
191191
self.nwbfile1.add_device(device)
192192

@@ -198,7 +198,7 @@ def test_roundtrip_from_probe_multi_shank(self):
198198
devices = read_nwbfile.devices
199199
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
200200

201-
def test_roundtrip_from_probegroup(self):
201+
def test_roundtrip_nwb_from_probegroup(self):
202202
devices = Probe.from_probegroup(self.probegroup)
203203
for device in devices:
204204
self.nwbfile2.add_device(device)
@@ -211,6 +211,16 @@ def test_roundtrip_from_probegroup(self):
211211
for device in devices:
212212
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
213213

214+
def test_roundtrip_pi_from_probe_single_shank(self):
215+
probe_arr = self.probe0.to_numpy()
216+
device = Probe.from_probe(self.probe0)
217+
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
218+
219+
def test_roundtrip_pi_from_probe_multi_shank(self):
220+
probe_arr = self.probe1.to_numpy()
221+
device = Probe.from_probe(self.probe1)
222+
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
223+
214224

215225
if __name__ == "__main__":
216226
# test = TestProbeConstructors()
@@ -223,6 +233,8 @@ def test_roundtrip_from_probegroup(self):
223233
test1.setUp()
224234
# test.test_constructor_from_probe_single_shank()
225235
# test.test_constructor_from_probe_multi_shank()
226-
test1.test_roundtrip_from_probe_single_shank()
227-
test1.test_roundtrip_from_probe_multi_shank()
228-
test1.test_roundtrip_from_probegroup()
236+
test1.test_roundtrip_pi_from_probe_single_shank()
237+
test1.test_roundtrip_pi_from_probe_multi_shank()
238+
239+
# test1.test_roundtrip_from_probe_multi_shank()
240+
# test1.test_roundtrip_from_probegroup()

0 commit comments

Comments
 (0)