diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 7c388713d7..5529c56a4f 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -10,7 +10,7 @@ import numpy as np -from probeinterface import generate_multi_columns_probe +from probeinterface import generate_multi_columns_probe, get_probe, generate_tetrode from spikeinterface import Templates from spikeinterface.core import ms_to_samples @@ -24,52 +24,56 @@ from .drift_tools import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording from .noise_tools import generate_noise -# this should be moved in probeinterface but later -_toy_probes = { - "Neuropixels1-384": dict( - num_columns=4, - num_contact_per_column=[96] * 4, - xpitch=16, - ypitch=40, - y_shift_per_column=[20, 0, 20, 0], - contact_shapes="square", - contact_shape_params={"width": 12}, - ), - "Neuropixels2-384": dict( - num_columns=2, - num_contact_per_column=[192] * 2, - xpitch=32, - ypitch=15, - contact_shapes="square", - contact_shape_params={"width": 12}, - ), - "Neuropixels2-128": dict( - num_columns=2, - num_contact_per_column=[64] * 2, - xpitch=32, - ypitch=15, - contact_shapes="square", - contact_shape_params={"width": 12}, - ), - "Neuropixels1-128": dict( - num_columns=4, - num_contact_per_column=[32] * 4, - xpitch=16, - ypitch=40, - y_shift_per_column=[20, 0, 20, 0], - contact_shapes="square", - contact_shape_params={"width": 12}, - ), - "Neuronexus-32": dict( - num_columns=3, - num_contact_per_column=[10, 12, 10], - xpitch=30, - ypitch=30, - y_shift_per_column=[0, -15, 0], - contact_shapes="circle", - contact_shape_params={"radius": 8}, - ), -} + +def _make_probe_by_name(probe_name: str): + """ + Generates a probe from probeinterface library, using the manufacturer name combined with the probe name, e.g. + - 'cambridgeneurotech/ASSY-37-H7b' + - 'cambridgeneurotech#ASSY-37-H7b' + - 'imec#NP1000" + + This function replace the old `_toy_probes` dict that generate probe using `generate_multi_columns_probe()` + """ + if probe_name == "Neuropixels1-384": + probe = get_probe("imec", "NP1000") + probe = probe.get_slice(np.arange(384)) + elif probe_name == "Neuropixels1-128": + probe = get_probe("imec", "NP1000") + probe = probe.get_slice(np.arange(128)) + elif probe_name == "Neuropixels2-384": + probe = get_probe("imec", "NP2000") + probe = probe.get_slice(np.arange(384)) + elif probe_name == "Neuropixels2-128": + probe = get_probe("imec", "NP2000") + probe = probe.get_slice(np.arange(128)) + elif probe_name == "Neuronexus-32": + # this probe was not existing really, it was a 'generic' 32 channels + # lets keep as before + probe = generate_multi_columns_probe( + num_columns=3, + num_contact_per_column=[10, 12, 10], + xpitch=30, + ypitch=30, + y_shift_per_column=[0, -15, 0], + contact_shapes="circle", + contact_shape_params={"radius": 8}, + ) + elif probe_name == "tetrode": + probe = generate_tetrode() + elif probe_name == "sinaps-128": + probe = get_probe("sinaps-research-platform", "p1024s1NHP") + order = np.argsort(probe.contact_ids.astype("int64")) + probe = probe.get_slice(order[:128]) + elif "/" in probe_name: + manufacturer, probe_name_ = probe_name.split("/") + probe = get_probe(manufacturer, probe_name_) + elif "#" in probe_name: + manufacturer, probe_name_ = probe_name.split("#") + probe = get_probe(manufacturer, probe_name_) + else: + raise ValueError("wring probe_name") + + return probe def make_one_displacement_vector( @@ -449,10 +453,11 @@ def generate_drifting_recording( # probe if probe is None: - if generate_probe_kwargs is None: - generate_probe_kwargs = _toy_probes[probe_name] - - probe = generate_multi_columns_probe(**generate_probe_kwargs) + if generate_probe_kwargs is not None: + probe = generate_multi_columns_probe(**generate_probe_kwargs) + else: + probe = _make_probe_by_name(probe_name) + # the wiring do not matter because the traces are generated after using the channel locations num_channels = probe.get_contact_count() probe.set_device_channel_indices(np.arange(num_channels)) diff --git a/src/spikeinterface/generation/tests/test_drifing_generator.py b/src/spikeinterface/generation/tests/test_drifing_generator.py index a7f6040ec2..5eb49a04c6 100644 --- a/src/spikeinterface/generation/tests/test_drifing_generator.py +++ b/src/spikeinterface/generation/tests/test_drifing_generator.py @@ -95,6 +95,23 @@ def test_generate_drifting_recording(): seed=2205, ) + static_recording, drifting_recording, sorting = generate_drifting_recording( + num_units=10, + probe_name="cambridgeneurotech/ASSY-1-E-1", + seed=2205, + ) + static_recording, drifting_recording, sorting = generate_drifting_recording( + num_units=10, + probe_name="cambridgeneurotech#ASSY-1-E-1", + seed=2205, + ) + + static_recording, drifting_recording, sorting = generate_drifting_recording( + num_units=10, + probe_name="Neuropixels2-128", + seed=2205, + ) + # print(static_recording) # print(drifting_recording) # print(sorting) @@ -104,6 +121,6 @@ def test_generate_drifting_recording(): if __name__ == "__main__": # test_make_one_displacement_vector() - test_generate_displacement_vector() + # test_generate_displacement_vector() # test_generate_noise() - # test_generate_drifting_recording() + test_generate_drifting_recording()