Skip to content

Commit e54c0f9

Browse files
authored
Merge pull request #18 from SpikeInterface/change-from-functions
Change `from_*` functions
2 parents 16ef4c9 + 16bf57c commit e54c0f9

File tree

5 files changed

+71
-32
lines changed

5 files changed

+71
-32
lines changed

.github/workflows/full_tests.yml

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ on:
44
pull_request:
55
branches: [main]
66
types: [synchronize, opened, reopened]
7-
workflow_dispatch:
8-
schedule:
9-
- cron: "0 12 * * *" # Daily at noon UTC
107

118

129
jobs:
1310
build-and-test:
1411

15-
runs-on: [ubuntu-latest, 'windows-latest', 'macos-latest']
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
1617

1718
steps:
1819
- uses: actions/checkout@v3
@@ -27,9 +28,4 @@ jobs:
2728
pip install -e .
2829
- name: Pytest
2930
run: |
30-
pytest --cov=probeinterface --cov-report xml:./coverage.xml
31-
- uses: codecov/codecov-action@v3
32-
with:
33-
token: ${{ secrets.CODECOV_TOKEN }}
34-
fail_ci_if_error: true
35-
file: ./coverage.xml
31+
pytest -v

README.md

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,44 @@ pip install ndx_probeinterface
99

1010
## Usage
1111

12-
### Going from a `ndx_probeinterface.Probe` object to a `probeinterface.Probe` object
12+
### Going from a `probeinterface.Probe`/``ProbeGroup` object to a `ndx_probeinterface.Probe` object
1313
```python
1414
import ndx_probeinterface
15-
pi_probe = ndx_probeinterface.to_probeinterface(ndx_probe)
15+
16+
pi_probe = probeinterface.Probe(...)
17+
pi_probegroup = probeinterface.ProbeGroup()
18+
19+
# from_probeinterface always returns a list of ndx_probeinterface.Probe devices
20+
ndx_probes1 = ndx_probeinterface.from_probeinterface(pi_probe)
21+
ndx_probes2 = ndx_probeinterface.from_probeinterface(pi_probegroup)
22+
23+
ndx_probes = ndx_probes1.extend(ndx_probes2)
24+
25+
nwbfile = pynwb.NWBFile(...)
26+
27+
# add Probe as NWB Devices
28+
for ndx_probe in ndx_probes:
29+
nwbfile.add_device(ndx_probe)
1630
```
1731

18-
### Going from a `probeinterface.Probe` object to a `ndx_probeinterface.Probe` object
32+
### Going from a `ndx_probeinterface.Probe` object to a `probeinterface.Probe` object
1933
```python
2034
import ndx_probeinterface
21-
ndx_probe = ndx_probeinterface.from_probe(pi_probe)
35+
36+
# load ndx_probeinterface.Probe objects from NWB file
37+
io = pynwb.NWBH5IO(file_path, 'r', load_namespaces=True)
38+
nwbfile = io.read()
39+
40+
ndx_probes = []
41+
for device in nwbfile:
42+
if isinstance(device, ndx_probeinterface.Probe):
43+
ndx_probes.append(device)
44+
45+
# convert to probeinterface.Probe objects
46+
pi_probes = []
47+
for ndx_probe in ndx_probes:
48+
pi_probe = ndx_probeinterface.to_probeinterface(ndx_probe)
49+
pi_probes.append(pi_probe)
2250
```
2351

2452
## Future plans

src/pynwb/ndx_probeinterface/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222

2323

2424
# Add custom constructors
25-
from .io import from_probe, from_probegroup, to_probeinterface
25+
from .io import from_probeinterface, to_probeinterface
2626

27-
Probe.from_probe = from_probe
28-
Probe.from_probegroup = from_probegroup
27+
Probe.from_probeinterface = from_probeinterface
2928
Probe.to_probeinterface = to_probeinterface

src/pynwb/ndx_probeinterface/io.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Union, List, Optional
22
import numpy as np
33
from probeinterface import Probe, ProbeGroup
4+
from pynwb.file import Device
45

56
unit_map = {
67
"um": "micrometer",
@@ -10,7 +11,7 @@
1011
inverted_unit_map = {v: k for k, v in unit_map.items()}
1112

1213

13-
def from_probe(probe: Probe):
14+
def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[Device]:
1415
"""
1516
Construct ndx-probeinterface Probe devices from a probeinterface.Probe
1617
@@ -21,11 +22,18 @@ def from_probe(probe: Probe):
2122
2223
Returns
2324
-------
24-
devices: ndx_probeinterface.Probe
25-
The ndx-probeinterface Probe device
25+
devices: list
26+
The list of ndx-probeinterface Probe devices
2627
"""
27-
assert isinstance(probe, Probe)
28-
return _single_probe_to_nwb_device(probe)
28+
assert isinstance(probe_or_probegroup, (Probe, ProbeGroup)), f"The input must be a Probe or ProbeGroup, not {type(probe_or_probegroup)}"
29+
if isinstance(probe_or_probegroup, Probe):
30+
probes = [probe_or_probegroup]
31+
else:
32+
probes = probe_or_probegroup.probes
33+
devices = []
34+
for probe in probes:
35+
devices.append(_single_probe_to_nwb_device(probe))
36+
return devices
2937

3038

3139
def from_probegroup(probegroup: ProbeGroup):

src/pynwb/tests/test_probe.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def test_constructor_from_probe_single_shank(self):
5858
"""Test that the constructor from Probe sets values as expected for single-shank."""
5959

6060
probe = self.probe0
61-
device = Probe.from_probe(probe)
61+
devices = Probe.from_probeinterface(probe)
62+
device = devices[0]
6263
# assert correct objects
6364
self.assertIsInstance(device, Device)
6465
self.assertIsInstance(device, Probe)
@@ -76,7 +77,8 @@ def test_constructor_from_probe_single_shank(self):
7677
# set channel indices
7778
device_channel_indices = np.arange(probe.get_contact_count())
7879
probe.set_device_channel_indices(device_channel_indices)
79-
device_w_indices = Probe.from_probe(probe)
80+
devices_w_indices = Probe.from_probeinterface(probe)
81+
device_w_indices = devices_w_indices[0]
8082
shank_names = list(device_w_indices.shanks.keys())
8183
contact_table = device_w_indices.shanks[shank_names[0]].contact_table
8284
np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices)
@@ -85,7 +87,8 @@ def test_constructor_from_probe_multi_shank(self):
8587
"""Test that the constructor from Probe sets values as expected for multi-shank."""
8688

8789
probe = self.probe1
88-
device = Probe.from_probe(probe)
90+
devices = Probe.from_probeinterface(probe)
91+
device = devices[0]
8992
# assert correct objects
9093
self.assertIsInstance(device, Device)
9194
self.assertIsInstance(device, Probe)
@@ -100,7 +103,8 @@ def test_constructor_from_probe_multi_shank(self):
100103
# set channel indices
101104
device_channel_indices = np.arange(probe.get_contact_count())
102105
probe.set_device_channel_indices(device_channel_indices)
103-
device_w_indices = Probe.from_probe(probe)
106+
devices_w_indices = Probe.from_probeinterface(probe)
107+
device_w_indices = devices_w_indices[0]
104108
for i_s, shank_name in enumerate(shank_names):
105109
contact_table = device_w_indices.shanks[shank_name].contact_table
106110
pi_shank = probe.get_shanks()[i_s]
@@ -120,7 +124,7 @@ def test_constructor_from_probegroup(self):
120124
probegroup = self.probegroup
121125
global_device_channel_indices = np.arange(probegroup.get_channel_count())
122126
probegroup.set_global_device_channel_indices(global_device_channel_indices)
123-
devices = Probe.from_probegroup(probegroup)
127+
devices = Probe.from_probeinterface(probegroup)
124128
probes = probegroup.probes
125129
shank_counts = [1, 2]
126130

@@ -175,7 +179,8 @@ def tearDown(self):
175179
remove_test_file(path)
176180

177181
def test_roundtrip_nwb_from_probe_single_shank(self):
178-
device = Probe.from_probe(self.probe0)
182+
devices = Probe.from_probeinterface(self.probe0)
183+
device = devices[0]
179184
self.nwbfile0.add_device(device)
180185

181186
with NWBHDF5IO(self.path0, mode="w") as io:
@@ -187,7 +192,8 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
187192
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
188193

189194
def test_roundtrip_nwb_from_probe_multi_shank(self):
190-
device = Probe.from_probe(self.probe1)
195+
devices = Probe.from_probeinterface(self.probe1)
196+
device = devices[0]
191197
self.nwbfile1.add_device(device)
192198

193199
with NWBHDF5IO(self.path1, mode="w") as io:
@@ -199,7 +205,7 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
199205
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
200206

201207
def test_roundtrip_nwb_from_probegroup(self):
202-
devices = Probe.from_probegroup(self.probegroup)
208+
devices = Probe.from_probeinterface(self.probegroup)
203209
for device in devices:
204210
self.nwbfile2.add_device(device)
205211

@@ -213,12 +219,14 @@ def test_roundtrip_nwb_from_probegroup(self):
213219

214220
def test_roundtrip_pi_from_probe_single_shank(self):
215221
probe_arr = self.probe0.to_numpy()
216-
device = Probe.from_probe(self.probe0)
222+
devices = Probe.from_probeinterface(self.probe0)
223+
device = devices[0]
217224
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
218225

219226
def test_roundtrip_pi_from_probe_multi_shank(self):
220227
probe_arr = self.probe1.to_numpy()
221-
device = Probe.from_probe(self.probe1)
228+
devices = Probe.from_probeinterface(self.probe1)
229+
device = devices[0]
222230
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
223231

224232

0 commit comments

Comments
 (0)