Skip to content

Commit 77849af

Browse files
authored
Merge pull request #16 from SpikeInterface/add_tests
Add tests for roundtrip
2 parents cbe12c7 + 8cec41a commit 77849af

8 files changed

Lines changed: 475 additions & 311 deletions

File tree

.github/workflows/full_tests.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Full tests
2+
3+
on:
4+
pull_request:
5+
branches: [main]
6+
types: [synchronize, opened, reopened]
7+
workflow_dispatch:
8+
schedule:
9+
- cron: "0 12 * * *" # Daily at noon UTC
10+
11+
12+
jobs:
13+
build-and-test:
14+
15+
runs-on: [ubuntu-latest, 'windows-latest', 'macos-latest']
16+
17+
steps:
18+
- uses: actions/checkout@v3
19+
- name: Set up Python
20+
uses: actions/setup-python@v4
21+
with:
22+
python-version: "3.11"
23+
- name: Install package
24+
run: |
25+
python -m pip install --upgrade pip
26+
pip install -r requirements-dev.txt
27+
pip install -e .
28+
- name: Pytest
29+
run: |
30+
pytest --cov=probeinterface --cov-report xml:./coverage.xml
31+
- uses: codecov/codecov-action@v3
32+
with:
33+
token: ${{ secrets.CODECOV_TOKEN }}
34+
fail_ci_if_error: true
35+
file: ./coverage.xml
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
name: Release to PyPI
2+
3+
on:
4+
push:
5+
tags:
6+
- '*'
7+
jobs:
8+
release:
9+
10+
runs-on: ubuntu-latest
11+
12+
steps:
13+
- uses: actions/checkout@v3
14+
- name: Set up Python 3.8
15+
uses: actions/setup-python@v4
16+
with:
17+
python-version: "3.10"
18+
- name: Install Tools
19+
run: |
20+
python -m pip install --upgrade pip
21+
pip install -r requirements-dev.txt
22+
pip install .
23+
pip install setuptools wheel twine build
24+
- name: Test with pytest
25+
run: |
26+
pytest -v
27+
- name: Package and Upload
28+
env:
29+
STACKMANAGER_VERSION: ${{ github.event.release.tag_name }}
30+
TWINE_USERNAME: __token__
31+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
32+
run: |
33+
python -m build --sdist --wheel
34+
twine upload dist/*

spec/ndx-probeinterface.extensions.yaml

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,29 @@ groups:
5757
neurodata_type_inc: DynamicTable
5858
doc: Neural probe contacts according to probeinterface specification
5959
datasets:
60-
- name: contact_id
60+
- name: contact_position
6161
neurodata_type_inc: VectorData
62-
dtype: text
63-
doc: unique ID of the contact
62+
dtype: float
63+
dims:
64+
- - num_contacts
65+
- x, y
66+
- - num_contacts
67+
- x, y, z
68+
shape:
69+
- - null
70+
- 2
71+
- - null
72+
- 3
73+
doc: dimension of the probe
6474
- name: contact_shape
6575
neurodata_type_inc: VectorData
6676
dtype: text
6777
doc: shape of the contact; e.g. 'circle'
78+
- name: contact_id
79+
neurodata_type_inc: VectorData
80+
dtype: text
81+
doc: unique ID of the contact
82+
quantity: '?'
6883
- name: contact_plane_axes
6984
neurodata_type_inc: VectorData
7085
dtype: float
@@ -83,21 +98,24 @@ groups:
8398
- 2
8499
- 3
85100
doc: dimension of the probe
86-
- name: contact_position
101+
quantity: '?'
102+
- name: radius
87103
neurodata_type_inc: VectorData
88104
dtype: float
89-
dims:
90-
- - num_contacts
91-
- x, y
92-
- - num_contacts
93-
- x, y, z
94-
shape:
95-
- - null
96-
- 2
97-
- - null
98-
- 3
99-
doc: dimension of the probe
100-
- name: device_channel_index
105+
doc: Radius of a circular contact
106+
quantity: '?'
107+
- name: width
108+
neurodata_type_inc: VectorData
109+
dtype: float
110+
doc: Width of a rectangular or square contact
111+
quantity: '?'
112+
- name: height
113+
neurodata_type_inc: VectorData
114+
dtype: float
115+
doc: Height of a rectangular contact
116+
quantity: '?'
117+
- name: device_channel_index_pi
101118
neurodata_type_inc: VectorData
102119
dtype: int
103120
doc: ID of the channel connected to the contact
121+
quantity: '?'

src/pynwb/ndx_probeinterface/__init__.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,28 @@
22
from pynwb import load_namespaces, get_class
33

44
# Set path of the namespace.yaml file to the expected install location
5-
ndx_probeinterface_specpath = os.path.join(
6-
os.path.dirname(__file__),
7-
'spec',
8-
'ndx-probeinterface.namespace.yaml'
9-
)
5+
ndx_probeinterface_specpath = os.path.join(os.path.dirname(__file__), "spec", "ndx-probeinterface.namespace.yaml")
106

117
# If the extension has not been installed yet but we are running directly from
128
# the git repo
139
if not os.path.exists(ndx_probeinterface_specpath):
14-
ndx_probeinterface_specpath = os.path.abspath(os.path.join(
15-
os.path.dirname(__file__),
16-
'..', '..', '..',
17-
'spec',
18-
'ndx-probeinterface.namespace.yaml'
19-
))
10+
ndx_probeinterface_specpath = os.path.abspath(
11+
os.path.join(os.path.dirname(__file__), "..", "..", "..", "spec", "ndx-probeinterface.namespace.yaml")
12+
)
2013

2114
# Load the namespace
2215
load_namespaces(ndx_probeinterface_specpath)
2316

2417
# TODO: import your classes here or define your class using get_class to make
2518
# them accessible at the package level
26-
Probe = get_class('Probe', 'ndx-probeinterface')
27-
Shank = get_class('Shank', 'ndx-probeinterface')
28-
ContactTable = get_class('ContactTable', 'ndx-probeinterface')
19+
Probe = get_class("Probe", "ndx-probeinterface")
20+
Shank = get_class("Shank", "ndx-probeinterface")
21+
ContactTable = get_class("ContactTable", "ndx-probeinterface")
2922

3023

3124
# Add custom constructors
3225
from .io import from_probe, from_probegroup, to_probeinterface
26+
3327
Probe.from_probe = from_probe
3428
Probe.from_probegroup = from_probegroup
35-
Probe.to_probeinterface = to_probeinterface
29+
Probe.to_probeinterface = to_probeinterface

src/pynwb/ndx_probeinterface/io.py

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
}
1010
inverted_unit_map = {v: k for k, v in unit_map.items()}
1111

12-
shape_words = ['radius', 'width', 'width/height']
1312

1413
def from_probe(probe: Probe):
1514
"""
@@ -19,10 +18,10 @@ def from_probe(probe: Probe):
1918
----------
2019
probe_or_probegroup: Probe or ProbeGroup
2120
Probe or ProbeGroup to convert to ndx-probeinterface Probe devices
22-
21+
2322
Returns
2423
-------
25-
devices: ndx_probeinterface.Probe
24+
devices: ndx_probeinterface.Probe
2625
The ndx-probeinterface Probe device
2726
"""
2827
assert isinstance(probe, Probe)
@@ -50,69 +49,89 @@ def from_probegroup(probegroup: ProbeGroup):
5049
return devices
5150

5251

53-
def to_probeinterface(ndx_Probe)->Probe:
52+
def to_probeinterface(ndx_probe) -> Probe:
5453
"""
5554
Construct a probeinterface.Probe from ndx_probeinterface.Probe
5655
5756
Parameters
5857
----------
59-
ndx_Probe: ndx_probeinterface.Probe
60-
ndx_probeinterface.Probe to convert to probeinterface.Probe
61-
58+
ndx_probe: ndx_probeinterface.Probe
59+
ndx_probeinterface.Probe to convert to probeinterface.Probe
60+
6261
Returns
6362
-------
64-
Probe: probeinterface.Probe
63+
Probe: probeinterface.Probe
6564
"""
66-
ndim = ndx_Probe.ndim
67-
unit = inverted_unit_map[ndx_Probe.unit]
68-
polygon = ndx_Probe.planar_contour
69-
65+
ndim = ndx_probe.ndim
66+
unit = inverted_unit_map[ndx_probe.unit]
67+
polygon = ndx_probe.planar_contour
68+
7069
positions = []
71-
contact_ids = []
7270
shapes = []
73-
shape_params = []
74-
shank_ids = []
75-
plane_axes = []
76-
channel_indices = []
77-
for shank in ndx_Probe.shanks.values():
78-
positions.append(shank.contact_table['contact_position'][:])
79-
contact_ids.append(shank.contact_table['contact_id'][:])
80-
shapes.append(shank.contact_table['contact_shape'][:])
81-
channel_indices.append(shank.contact_table['device_channel_index'][:])
82-
plane_axes.append(shank.contact_table['contact_plane_axes'][:])
83-
shank_ids.append([int(shank.shank_id)] * len(shank.contact_table))
84-
# WARNING: currently assumes that all the contacts have the same shape
85-
shape_word = [shape for shape in shape_words if shape in shank.contact_table[:].columns][0]
86-
shape_params.append([{shape_word: val} for val in shank.contact_table[shape_word][:]])
71+
72+
contact_ids = None
73+
shape_params = None
74+
shank_ids = None
75+
plane_axes = None
76+
device_channel_indices = None
77+
78+
possible_shape_keys = ["radius", "width", "height"]
79+
for shank in ndx_probe.shanks.values():
80+
positions.append(shank.contact_table["contact_position"][:])
81+
shapes.append(shank.contact_table["contact_shape"][:])
82+
if "contact_id" in shank.contact_table.colnames:
83+
if contact_ids is None:
84+
contact_ids = []
85+
contact_ids.append(shank.contact_table["contact_id"][:])
86+
if "device_channel_index_pi" in shank.contact_table.colnames:
87+
if device_channel_indices is None:
88+
device_channel_indices = []
89+
device_channel_indices.append(shank.contact_table["device_channel_index_pi"][:])
90+
if "contact_plane_axes" in shank.contact_table.colnames:
91+
if plane_axes is None:
92+
plane_axes = []
93+
plane_axes.append(shank.contact_table["contact_plane_axes"][:])
94+
if shank_ids is None:
95+
shank_ids = []
96+
shank_ids.append([str(shank.shank_id)] * len(shank.contact_table))
97+
for possible_shape_key in possible_shape_keys:
98+
if possible_shape_key in shank.contact_table.colnames:
99+
if shape_params is None:
100+
shape_params = []
101+
shape_params.append([{possible_shape_key: val} for val in shank.contact_table[possible_shape_key][:]])
87102

88103
positions = [item for sublist in positions for item in sublist]
89-
contact_ids = [item for sublist in contact_ids for item in sublist]
90104
shapes = [item for sublist in shapes for item in sublist]
91-
plane_axes = [item for sublist in plane_axes for item in sublist]
92-
shank_ids = [item for sublist in shank_ids for item in sublist]
93-
channel_indices = [item for sublist in channel_indices for item in sublist]
94-
shape_params = [item for sublist in shape_params for item in sublist]
95105

96-
probeinterface_Probe = Probe(ndim=ndim, si_units=unit)
97-
probeinterface_Probe.set_contacts(positions=positions,
98-
shapes=shapes,
99-
shape_params=shape_params,
100-
plane_axes=plane_axes,
101-
shank_ids=shank_ids)
102-
probeinterface_Probe.set_contact_ids(contact_ids=contact_ids)
103-
probeinterface_Probe.set_device_channel_indices(channel_indices=channel_indices)
106+
if contact_ids is not None:
107+
contact_ids = [item for sublist in contact_ids for item in sublist]
108+
if plane_axes is not None:
109+
plane_axes = [item for sublist in plane_axes for item in sublist]
110+
if shape_params is not None:
111+
shape_params = [item for sublist in shape_params for item in sublist]
112+
if shank_ids is not None:
113+
shank_ids = [item for sublist in shank_ids for item in sublist]
114+
if device_channel_indices is not None:
115+
device_channel_indices = [item for sublist in channel_indices for item in sublist]
104116

105-
probeinterface_Probe.set_planar_contour(polygon)
117+
probeinterface_probe = Probe(ndim=ndim, si_units=unit)
118+
probeinterface_probe.set_contacts(
119+
positions=positions, shapes=shapes, shape_params=shape_params, plane_axes=plane_axes, shank_ids=shank_ids
120+
)
121+
probeinterface_probe.set_contact_ids(contact_ids=contact_ids)
122+
if device_channel_indices is not None:
123+
probeinterface_probe.set_device_channel_indices(channel_indices=device_channel_indices)
124+
probeinterface_probe.set_planar_contour(polygon)
106125

107-
return probeinterface_Probe
126+
return probeinterface_probe
108127

109128

110129
def _single_probe_to_nwb_device(probe: Probe):
111130
from pynwb import load_namespaces, get_class
112131

113-
Probe = get_class('Probe', 'ndx-probeinterface')
114-
Shank = get_class('Shank', 'ndx-probeinterface')
115-
ContactTable = get_class('ContactTable', 'ndx-probeinterface')
132+
Probe = get_class("Probe", "ndx-probeinterface")
133+
Shank = get_class("Shank", "ndx-probeinterface")
134+
ContactTable = get_class("ContactTable", "ndx-probeinterface")
116135

117136
contact_positions = probe.contact_positions
118137
contact_plane_axes = probe.contact_plane_axes
@@ -150,13 +169,6 @@ def _single_probe_to_nwb_device(probe: Probe):
150169
name="ContactTable",
151170
description="Contact Table for ProbeInterface",
152171
)
153-
154-
if probe.device_channel_indices is not None:
155-
contact_table.add_column(name="device_channel_index",
156-
description="Device channel index")
157-
for k in shape_keys:
158-
contact_table.add_column(name=k,
159-
description="Shape parameter for electrode")
160172

161173
for index in shank_indices:
162174
kwargs = dict(
@@ -168,12 +180,10 @@ def _single_probe_to_nwb_device(probe: Probe):
168180
for k in shape_keys:
169181
kwargs[k] = contacts_arr[k][index]
170182
if probe.device_channel_indices is not None:
171-
kwargs["device_channel_index"] = probe.device_channel_indices[index]
183+
kwargs["device_channel_index_pi"] = probe.device_channel_indices[index]
172184
contact_table.add_row(kwargs)
173185
contact_tables.append(contact_table)
174-
shank = Shank(name=shank_name,
175-
shank_id=shank_id,
176-
contact_table=contact_table)
186+
shank = Shank(name=shank_name, shank_id=shank_id, contact_table=contact_table)
177187
shanks.append(shank)
178188

179189
if "serial_number" in probe.annotations:
@@ -197,7 +207,7 @@ def _single_probe_to_nwb_device(probe: Probe):
197207
manufacturer=manufacturer,
198208
ndim=probe.ndim,
199209
unit=unit_map[probe.si_units],
200-
planar_contour=planar_contour
210+
planar_contour=planar_contour,
201211
)
202212

203213
return probe_device

0 commit comments

Comments
 (0)