Skip to content
57 changes: 47 additions & 10 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2156,6 +2156,21 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
return channel_locations


def _generate_multimodal(rng, size, num_modes, lim0, lim1):
bins = np.linspace(lim0, lim1, 10000)
bin_step = bins[1] - bins[0]
prob = np.zeros(bins.size)
mode_step = (lim1 - lim0) / (num_modes + 1)
for i in range(num_modes):
center = mode_step * (i + 1)
sigma = mode_step / 5.0
prob += np.exp(-((bins - center) ** 2) / (2 * sigma**2))
prob /= np.sum(prob)
choices = np.random.choice(np.arange(bins.size), size, p=prob)
values = bins[choices] + rng.uniform(low=-bin_step / 2, high=bin_step / 2, size=size)
return values


def generate_unit_locations(
num_units,
channel_locations,
Expand All @@ -2165,6 +2180,8 @@ def generate_unit_locations(
minimum_distance=20.0,
max_iteration=100,
distance_strict=False,
distribution="uniform",
num_modes=2,
seed=None,
):
"""
Expand Down Expand Up @@ -2205,6 +2222,14 @@ def generate_unit_locations(
If True, the function will raise an exception if a solution meeting the distance
constraint cannot be found within the maximum number of iterations. If False, a warning
will be issued.
distribution : "uniform" | "multimodal", default: "uniform"
How units are spread.
"uniform" is units everywhere
"multimodal" mimic the distribution of units 'by layer' on the 'y' axis (dim=1)
Important note, when using multimodal in conjonction of minimum_distance not None, there is not garanty
of a true multimodal because units that do not respect the distance of move again and are most chance to be in between layers.
num_modes : int, default 2
In case of distribution="multimodal", this is the number of modes (layers)
seed : int or None, optional
Random seed for reproducibility. If None, the seed is not set

Expand All @@ -2221,7 +2246,12 @@ def generate_unit_locations(
minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um

units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
if distribution == "uniform":
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
elif distribution == "multimodal":
units_locations[:, 1] = _generate_multimodal(rng, num_units, num_modes, minimum_y, maximum_y)
else:
raise ValueError("generate_unit_locations has wrong distribution must be 'uniform' or ")
units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)

if minimum_distance is not None:
Expand All @@ -2242,20 +2272,27 @@ def generate_unit_locations(
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]

units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
if distribution == "uniform":
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)

elif distribution == "multimodal":
units_locations[:, 1][renew_inds] = _generate_multimodal(
rng, renew_inds.size, num_modes, minimum_y, maximum_y
)
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)

else:
solution_found = True
break

if not solution_found:
if distance_strict:
raise ValueError(
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance"
)
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")
if not solution_found:
if distance_strict:
raise ValueError(
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
"You can use distance_strict=False or reduce minimum distance"
)
else:
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")

return units_locations

Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,8 @@ def test_synthesize_random_firings_length():
# test_generate_recording()
# test_generate_single_fake_waveform()
# test_transformsorting()
test_generate_templates()
test_generate_unit_locations()
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
# test_generate_sorting_with_spikes_on_borders()
84 changes: 65 additions & 19 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# this should be moved in probeinterface but later
_toy_probes = {
"Neuropixel-384": dict(
"Neuropixel1-384": dict(
Comment thread
samuelgarcia marked this conversation as resolved.
Outdated
num_columns=4,
num_contact_per_column=[96] * 4,
xpitch=16,
Expand All @@ -34,7 +34,15 @@
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuropixel-128": dict(
"Neuropixel2-384": dict(
Comment thread
samuelgarcia marked this conversation as resolved.
Outdated
num_columns=2,
num_contact_per_column=[192] * 2,
xpitch=32,
ypitch=15,
contact_shapes="square",
contact_shape_params={"width": 12},
),
"Neuropixel1-128": dict(
num_columns=4,
num_contact_per_column=[32] * 4,
xpitch=16,
Expand Down Expand Up @@ -69,6 +77,8 @@ def make_one_displacement_vector(
):
"""
Generates a toy displacement vector with ziagzag or bumps patterns.
This displacement vector has no amplitde, this generate only the shape
in the range [-0.5, 0.5]

Parameters
----------
Expand Down Expand Up @@ -141,8 +151,19 @@ def make_one_displacement_vector(
else:
displacement_vector[ind0:ind1] = -0.5

elif drift_mode == "random_walk":
rg = np.random.RandomState(seed=seed)
steps = rg.random_integers(low=0, high=1, size=num_samples)
steps = steps.astype("float64")
# 0 -> -1 and 1 -> 1
steps = steps * 2 - 1
steps[:start_drift_index] = 0
steps[end_drift_index:] = 0
displacement_vector = np.cumsum(steps, dtype="float64")
displacement_vector /= np.max(np.abs(displacement_vector)) * 2

else:
raise ValueError("drift_mode must be 'zigzag' or 'bump'")
raise ValueError("drift_mode must be 'zigzag' or 'bump' or 'random_walk'")

return displacement_vector * amplitude_factor

Expand All @@ -151,8 +172,8 @@ def generate_displacement_vector(
duration,
unit_locations,
displacement_sampling_frequency=5.0,
drift_start_um=[0, 20.0],
drift_stop_um=[0, -20.0],
drift_start_um=[0, 30.0],
drift_stop_um=[0, -30.0],
drift_step_um=1,
motion_list=[
dict(
Expand Down Expand Up @@ -199,6 +220,8 @@ def generate_displacement_vector(

Returns
-------
unit_displacements : numpy.ndarray
The final per unit, displacement vector with shape (num_times, num_units, 2)
displacement_vectors : numpy.ndarray
The drift vector is a numpy array with shape (num_times, 2, num_motions)
num_motions is generally 1, but can be > 1 in case of combining several drift vectors
Expand Down Expand Up @@ -234,7 +257,9 @@ def generate_displacement_vector(
**motion_kwargs,
seed=seed,
)

one_displacement = one_displacement[:, np.newaxis] * (drift_stop_um - drift_start_um) + mid

displacement_vectors.append(one_displacement[:, :, np.newaxis])

if non_rigid_gradient is None:
Expand All @@ -253,14 +278,36 @@ def generate_displacement_vector(

displacement_vectors = np.concatenate(displacement_vectors, axis=2)

return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps
# unit_displacements is the sum of all discplacements (times, units, direction_x_y)
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
for direction in (0, 1):
# x and y
for i in range(displacement_vectors.shape[2]):
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
unit_displacements[:, :, direction] += m

lim0 = min(drift_start_um[direction], drift_stop_um[direction])
lim1 = max(drift_start_um[direction], drift_stop_um[direction])
if np.min(unit_displacements[:, :, direction]) < lim0 or np.max(unit_displacements[:, :, direction]) > lim1:
raise ValueError(
"unit_displacements is too big when combining several motion (with motion_list)."
"Please consider a smaller 'amplitude_factor' for each motion"
)

return (
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
)


def generate_drifting_recording(
num_units=250,
duration=600.0,
sampling_frequency=30000.0,
probe_name="Neuropixel-128",
probe_name="Neuropixel1-128",
generate_probe_kwargs=None,
generate_unit_locations_kwargs=dict(
margin_um=20.0,
Expand All @@ -269,6 +316,9 @@ def generate_drifting_recording(
minimum_distance=18.0,
max_iteration=100,
distance_strict=False,
distribution="uniform",
# distribution="multimodal",
# num_modes=2,
),
generate_displacement_vector_kwargs=dict(
displacement_sampling_frequency=5.0,
Expand Down Expand Up @@ -311,7 +361,7 @@ def generate_drifting_recording(
The duration in seconds.
sampling_frequency : float, dfault: 30000.
The sampling frequency.
probe_name : str, default: "Neuropixel-128"
probe_name : str, default: "Neuropixel1-128"
The probe type if generate_probe_kwargs is None.
generate_probe_kwargs : None or dict
A dict to generate the probe, this supersede probe_name when not None.
Expand Down Expand Up @@ -371,17 +421,13 @@ def generate_drifting_recording(
**generate_unit_locations_kwargs,
)

displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
)

# unit_displacements is the sum of all discplacements (times, units, direction_x_y)
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
for direction in (0, 1):
# x and y
for i in range(displacement_vectors.shape[2]):
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
unit_displacements[:, :, direction] += m
(
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
Expand Down
73 changes: 43 additions & 30 deletions src/spikeinterface/generation/tests/test_drifing_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def test_make_one_displacement_vector():
drift_mode="bump", duration=700.0, period_s=300, bump_interval_s=(30, 90.0), t_start_drift=100.0, seed=2205
)

displacement_vector = make_one_displacement_vector(
drift_mode="random_walk", duration=700.0, t_start_drift=100.0, seed=2205
)

# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# ax.plot(displacement_vector)
Expand All @@ -29,39 +33,48 @@ def test_generate_displacement_vector():
unit_locations[:, 1] = np.linspace(-50, 50, 10)

# one motion Y only
displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
generate_displacement_vector(duration, unit_locations)
)
(
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = generate_displacement_vector(duration, unit_locations)
assert unit_locations.shape[0] == unit_displacements.shape[1]
assert unit_locations.shape[0] == displacement_unit_factor.shape[0]
assert displacement_vectors.shape[2] == displacement_unit_factor.shape[1]
assert displacement_vectors.shape[2] == 1

# two motion X and Y
displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
generate_displacement_vector(
duration,
unit_locations,
drift_start_um=[-5, 20.0],
drift_stop_um=[5, -20.0],
motion_list=[
dict(
drift_mode="zigzag",
amplitude_factor=1.0,
non_rigid_gradient=0.4,
t_start_drift=60.0,
t_end_drift=None,
period_s=200,
),
dict(
drift_mode="bump",
amplitude_factor=0.3,
non_rigid_gradient=0.4,
t_start_drift=60.0,
t_end_drift=None,
bump_interval_s=(30, 90.0),
),
],
)
(
unit_displacements,
displacement_vectors,
displacement_unit_factor,
displacement_sampling_frequency,
displacements_steps,
) = generate_displacement_vector(
duration,
unit_locations,
drift_start_um=[-5, 20.0],
drift_stop_um=[5, -20.0],
motion_list=[
dict(
drift_mode="zigzag",
amplitude_factor=0.7,
non_rigid_gradient=0.4,
t_start_drift=60.0,
t_end_drift=None,
period_s=200,
),
dict(
drift_mode="bump",
amplitude_factor=0.3,
non_rigid_gradient=0.4,
t_start_drift=60.0,
t_end_drift=None,
bump_interval_s=(30, 90.0),
),
],
)
assert unit_locations.shape[0] == displacement_unit_factor.shape[0]
assert displacement_vectors.shape[2] == displacement_unit_factor.shape[1]
Expand Down Expand Up @@ -91,6 +104,6 @@ def test_generate_drifting_recording():

if __name__ == "__main__":
# test_make_one_displacement_vector()
# test_generate_displacement_vector()
test_generate_displacement_vector()
# test_generate_noise()
test_generate_drifting_recording()
# test_generate_drifting_recording()