Skip to content
107 changes: 56 additions & 51 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from probeinterface import generate_multi_columns_probe
from probeinterface import generate_multi_columns_probe, get_probe, generate_tetrode

from spikeinterface import Templates
from spikeinterface.core import ms_to_samples
Expand All @@ -24,52 +24,56 @@
from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording
from .noise_tools import generate_noise

# this should be moved in probeinterface but later
_toy_probes = {
"Neuropixels1-384": dict(
num_columns=4,
num_contact_per_column=[96] * 4,
xpitch=16,
ypitch=40,
y_shift_per_column=[20, 0, 20, 0],
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuropixels2-384": dict(
num_columns=2,
num_contact_per_column=[192] * 2,
xpitch=32,
ypitch=15,
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuropixels2-128": dict(
num_columns=2,
num_contact_per_column=[64] * 2,
xpitch=32,
ypitch=15,
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuropixels1-128": dict(
num_columns=4,
num_contact_per_column=[32] * 4,
xpitch=16,
ypitch=40,
y_shift_per_column=[20, 0, 20, 0],
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuronexus-32": dict(
num_columns=3,
num_contact_per_column=[10, 12, 10],
xpitch=30,
ypitch=30,
y_shift_per_column=[0, -15, 0],
contact_shapes="circle",
contact_shape_params={"radius": 8},
),
}

def _make_probe_by_name(probe_name: str):
"""
Generates a probe from probeinterface library, using the manufacturer name combined with the probe name, e.g.
- 'cambridgeneurotech/ASSY-37-H7b'
- 'cambridgeneurotech#ASSY-37-H7b'
- 'imec#NP1000"

This function replace the old `_toy_probes` dict that generate probe using `generate_multi_columns_probe()`
"""
if probe_name == "Neuropixels1-384":
probe = get_probe("imec", "NP1000")
probe = probe.get_slice(np.arange(384))
elif probe_name == "Neuropixels1-128":
probe = get_probe("imec", "NP1000")
probe = probe.get_slice(np.arange(128))
elif probe_name == "Neuropixels2-384":
probe = get_probe("imec", "NP2000")
probe = probe.get_slice(np.arange(384))
elif probe_name == "Neuropixels2-128":
probe = get_probe("imec", "NP2000")
probe = probe.get_slice(np.arange(128))
elif probe_name == "Neuronexus-32":
# this probe was not existing really, it was a 'generic' 32 channels
# lets keep as before
probe = generate_multi_columns_probe(
num_columns=3,
num_contact_per_column=[10, 12, 10],
xpitch=30,
ypitch=30,
y_shift_per_column=[0, -15, 0],
contact_shapes="circle",
contact_shape_params={"radius": 8},
)
elif probe_name == "tetrode":
probe = generate_tetrode()
elif probe_name == "sinaps-128":
probe = get_probe("sinaps-research-platform", "p1024s1NHP")
order = np.argsort(probe.contact_ids.astype("int64"))
probe = probe.get_slice(order[:128])
elif "/" in probe_name:
manufacturer, probe_name_ = probe_name.split("/")
probe = get_probe(manufacturer, probe_name_)
elif "#" in probe_name:
manufacturer, probe_name_ = probe_name.split("#")
probe = get_probe(manufacturer, probe_name_)
else:
raise ValueError("wring probe_name")

return probe


def make_one_displacement_vector(
Expand Down Expand Up @@ -449,10 +453,11 @@ def generate_drifting_recording(

# probe
if probe is None:
if generate_probe_kwargs is None:
generate_probe_kwargs = _toy_probes[probe_name]

probe = generate_multi_columns_probe(**generate_probe_kwargs)
if generate_probe_kwargs is not None:
probe = generate_multi_columns_probe(**generate_probe_kwargs)
else:
probe = _make_probe_by_name(probe_name)
# the wiring do not matter because the traces are generated after using the channel locations
num_channels = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(num_channels))

Expand Down
21 changes: 19 additions & 2 deletions src/spikeinterface/generation/tests/test_drifing_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ def test_generate_drifting_recording():
seed=2205,
)

static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=10,
probe_name="cambridgeneurotech/ASSY-1-E-1",
seed=2205,
)
static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=10,
probe_name="cambridgeneurotech#ASSY-1-E-1",
seed=2205,
)

static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=10,
probe_name="Neuropixels2-128",
seed=2205,
)

# print(static_recording)
# print(drifting_recording)
# print(sorting)
Expand All @@ -104,6 +121,6 @@ def test_generate_drifting_recording():

if __name__ == "__main__":
# test_make_one_displacement_vector()
test_generate_displacement_vector()
# test_generate_displacement_vector()
# test_generate_noise()
# test_generate_drifting_recording()
test_generate_drifting_recording()
Loading