Skip to content

Commit 2ed7897

Browse files
authored
Merge pull request #3973 from samuelgarcia/improve_generators
Improve generators
2 parents 78ede06 + 6f2f60c commit 2ed7897

5 files changed

Lines changed: 158 additions & 61 deletions

File tree

src/spikeinterface/core/generate.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2156,6 +2156,21 @@ def generate_channel_locations(num_channels, num_columns, contact_spacing_um):
21562156
return channel_locations
21572157

21582158

2159+
def _generate_multimodal(rng, size, num_modes, lim0, lim1):
2160+
bins = np.linspace(lim0, lim1, 10000)
2161+
bin_step = bins[1] - bins[0]
2162+
prob = np.zeros(bins.size)
2163+
mode_step = (lim1 - lim0) / (num_modes + 1)
2164+
for i in range(num_modes):
2165+
center = mode_step * (i + 1)
2166+
sigma = mode_step / 5.0
2167+
prob += np.exp(-((bins - center) ** 2) / (2 * sigma**2))
2168+
prob /= np.sum(prob)
2169+
choices = np.random.choice(np.arange(bins.size), size, p=prob)
2170+
values = bins[choices] + rng.uniform(low=-bin_step / 2, high=bin_step / 2, size=size)
2171+
return values
2172+
2173+
21592174
def generate_unit_locations(
21602175
num_units,
21612176
channel_locations,
@@ -2165,6 +2180,8 @@ def generate_unit_locations(
21652180
minimum_distance=20.0,
21662181
max_iteration=100,
21672182
distance_strict=False,
2183+
distribution="uniform",
2184+
num_modes=2,
21682185
seed=None,
21692186
):
21702187
"""
@@ -2205,6 +2222,14 @@ def generate_unit_locations(
22052222
If True, the function will raise an exception if a solution meeting the distance
22062223
constraint cannot be found within the maximum number of iterations. If False, a warning
22072224
will be issued.
2225+
distribution : "uniform" | "multimodal", default: "uniform"
2226+
How units are spread.
2227+
"uniform" is units everywhere
2228+
"multimodal" mimic the distribution of units 'by layer' on the 'y' axis (dim=1)
2229+
Important note, when using multimodal in conjonction of minimum_distance not None, there is not garanty
2230+
of a true multimodal because units that do not respect the distance of move again and are most chance to be in between layers.
2231+
num_modes : int, default 2
2232+
In case of distribution="multimodal", this is the number of modes (layers)
22082233
seed : int or None, optional
22092234
Random seed for reproducibility. If None, the seed is not set
22102235
@@ -2221,7 +2246,12 @@ def generate_unit_locations(
22212246
minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um
22222247

22232248
units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)
2224-
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
2249+
if distribution == "uniform":
2250+
units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)
2251+
elif distribution == "multimodal":
2252+
units_locations[:, 1] = _generate_multimodal(rng, num_units, num_modes, minimum_y, maximum_y)
2253+
else:
2254+
raise ValueError("generate_unit_locations has wrong distribution must be 'uniform' or ")
22252255
units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)
22262256

22272257
if minimum_distance is not None:
@@ -2242,20 +2272,27 @@ def generate_unit_locations(
22422272
renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]
22432273

22442274
units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)
2245-
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
2275+
if distribution == "uniform":
2276+
units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)
2277+
2278+
elif distribution == "multimodal":
2279+
units_locations[:, 1][renew_inds] = _generate_multimodal(
2280+
rng, renew_inds.size, num_modes, minimum_y, maximum_y
2281+
)
22462282
units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)
2283+
22472284
else:
22482285
solution_found = True
22492286
break
22502287

2251-
if not solution_found:
2252-
if distance_strict:
2253-
raise ValueError(
2254-
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
2255-
"You can use distance_strict=False or reduce minimum distance"
2256-
)
2257-
else:
2258-
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")
2288+
if not solution_found:
2289+
if distance_strict:
2290+
raise ValueError(
2291+
f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=} "
2292+
"You can use distance_strict=False or reduce minimum distance"
2293+
)
2294+
else:
2295+
warnings.warn(f"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}")
22592296

22602297
return units_locations
22612298

src/spikeinterface/core/tests/test_generate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,8 @@ def test_synthesize_random_firings_length():
654654
# test_generate_recording()
655655
# test_generate_single_fake_waveform()
656656
# test_transformsorting()
657-
test_generate_templates()
657+
test_generate_unit_locations()
658+
# test_generate_templates()
658659
# test_inject_templates()
659660
# test_generate_ground_truth_recording()
660661
# test_generate_sorting_with_spikes_on_borders()

src/spikeinterface/core/tests/test_job_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_get_best_job_kwargs():
301301
# num_units=50,
302302
# duration=120.0,
303303
# sampling_frequency=30000.0,
304-
# probe_name="Neuropixel-128",
304+
# probe_name="Neuropixels-128",
305305

306306
# )
307307
# # print(rec)

src/spikeinterface/generation/drifting_generator.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
# this should be moved in probeinterface but later
2727
_toy_probes = {
28-
"Neuropixel-384": dict(
28+
"Neuropixels1-384": dict(
2929
num_columns=4,
3030
num_contact_per_column=[96] * 4,
3131
xpitch=16,
@@ -34,7 +34,15 @@
3434
contact_shapes="square",
3535
contact_shape_params={"width": 12},
3636
),
37-
"Neuropixel-128": dict(
37+
"Neuropixels2-384": dict(
38+
num_columns=2,
39+
num_contact_per_column=[192] * 2,
40+
xpitch=32,
41+
ypitch=15,
42+
contact_shapes="square",
43+
contact_shape_params={"width": 12},
44+
),
45+
"Neuropixels1-128": dict(
3846
num_columns=4,
3947
num_contact_per_column=[32] * 4,
4048
xpitch=16,
@@ -69,6 +77,8 @@ def make_one_displacement_vector(
6977
):
7078
"""
7179
Generates a toy displacement vector with ziagzag or bumps patterns.
80+
This displacement vector has no amplitde, this generate only the shape
81+
in the range [-0.5, 0.5]
7282
7383
Parameters
7484
----------
@@ -141,8 +151,19 @@ def make_one_displacement_vector(
141151
else:
142152
displacement_vector[ind0:ind1] = -0.5
143153

154+
elif drift_mode == "random_walk":
155+
rg = np.random.RandomState(seed=seed)
156+
steps = rg.random_integers(low=0, high=1, size=num_samples)
157+
steps = steps.astype("float64")
158+
# 0 -> -1 and 1 -> 1
159+
steps = steps * 2 - 1
160+
steps[:start_drift_index] = 0
161+
steps[end_drift_index:] = 0
162+
displacement_vector = np.cumsum(steps, dtype="float64")
163+
displacement_vector /= np.max(np.abs(displacement_vector)) * 2
164+
144165
else:
145-
raise ValueError("drift_mode must be 'zigzag' or 'bump'")
166+
raise ValueError("drift_mode must be 'zigzag' or 'bump' or 'random_walk'")
146167

147168
return displacement_vector * amplitude_factor
148169

@@ -151,8 +172,8 @@ def generate_displacement_vector(
151172
duration,
152173
unit_locations,
153174
displacement_sampling_frequency=5.0,
154-
drift_start_um=[0, 20.0],
155-
drift_stop_um=[0, -20.0],
175+
drift_start_um=[0, 30.0],
176+
drift_stop_um=[0, -30.0],
156177
drift_step_um=1,
157178
motion_list=[
158179
dict(
@@ -199,6 +220,8 @@ def generate_displacement_vector(
199220
200221
Returns
201222
-------
223+
unit_displacements : numpy.ndarray
224+
The final per unit, displacement vector with shape (num_times, num_units, 2)
202225
displacement_vectors : numpy.ndarray
203226
The drift vector is a numpy array with shape (num_times, 2, num_motions)
204227
num_motions is generally 1, but can be > 1 in case of combining several drift vectors
@@ -234,7 +257,9 @@ def generate_displacement_vector(
234257
**motion_kwargs,
235258
seed=seed,
236259
)
260+
237261
one_displacement = one_displacement[:, np.newaxis] * (drift_stop_um - drift_start_um) + mid
262+
238263
displacement_vectors.append(one_displacement[:, :, np.newaxis])
239264

240265
if non_rigid_gradient is None:
@@ -253,14 +278,36 @@ def generate_displacement_vector(
253278

254279
displacement_vectors = np.concatenate(displacement_vectors, axis=2)
255280

256-
return displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps
281+
# unit_displacements is the sum of all discplacements (times, units, direction_x_y)
282+
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
283+
for direction in (0, 1):
284+
# x and y
285+
for i in range(displacement_vectors.shape[2]):
286+
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
287+
unit_displacements[:, :, direction] += m
288+
289+
lim0 = min(drift_start_um[direction], drift_stop_um[direction])
290+
lim1 = max(drift_start_um[direction], drift_stop_um[direction])
291+
if np.min(unit_displacements[:, :, direction]) < lim0 or np.max(unit_displacements[:, :, direction]) > lim1:
292+
raise ValueError(
293+
"unit_displacements is too big when combining several motion (with motion_list)."
294+
"Please consider a smaller 'amplitude_factor' for each motion"
295+
)
296+
297+
return (
298+
unit_displacements,
299+
displacement_vectors,
300+
displacement_unit_factor,
301+
displacement_sampling_frequency,
302+
displacements_steps,
303+
)
257304

258305

259306
def generate_drifting_recording(
260307
num_units=250,
261308
duration=600.0,
262309
sampling_frequency=30000.0,
263-
probe_name="Neuropixel-128",
310+
probe_name="Neuropixels1-128",
264311
generate_probe_kwargs=None,
265312
generate_unit_locations_kwargs=dict(
266313
margin_um=20.0,
@@ -269,6 +316,9 @@ def generate_drifting_recording(
269316
minimum_distance=18.0,
270317
max_iteration=100,
271318
distance_strict=False,
319+
distribution="uniform",
320+
# distribution="multimodal",
321+
# num_modes=2,
272322
),
273323
generate_displacement_vector_kwargs=dict(
274324
displacement_sampling_frequency=5.0,
@@ -311,7 +361,7 @@ def generate_drifting_recording(
311361
The duration in seconds.
312362
sampling_frequency : float, dfault: 30000.
313363
The sampling frequency.
314-
probe_name : str, default: "Neuropixel-128"
364+
probe_name : str, default: "Neuropixels1-128"
315365
The probe type if generate_probe_kwargs is None.
316366
generate_probe_kwargs : None or dict
317367
A dict to generate the probe, this supersede probe_name when not None.
@@ -371,17 +421,13 @@ def generate_drifting_recording(
371421
**generate_unit_locations_kwargs,
372422
)
373423

374-
displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
375-
generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
376-
)
377-
378-
# unit_displacements is the sum of all discplacements (times, units, direction_x_y)
379-
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
380-
for direction in (0, 1):
381-
# x and y
382-
for i in range(displacement_vectors.shape[2]):
383-
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
384-
unit_displacements[:, :, direction] += m
424+
(
425+
unit_displacements,
426+
displacement_vectors,
427+
displacement_unit_factor,
428+
displacement_sampling_frequency,
429+
displacements_steps,
430+
) = generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
385431

386432
# unit_params need to be fixed before the displacement steps
387433
generate_templates_kwargs = generate_templates_kwargs.copy()

src/spikeinterface/generation/tests/test_drifing_generator.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ def test_make_one_displacement_vector():
1717
drift_mode="bump", duration=700.0, period_s=300, bump_interval_s=(30, 90.0), t_start_drift=100.0, seed=2205
1818
)
1919

20+
displacement_vector = make_one_displacement_vector(
21+
drift_mode="random_walk", duration=700.0, t_start_drift=100.0, seed=2205
22+
)
23+
2024
# import matplotlib.pyplot as plt
2125
# fig, ax = plt.subplots()
2226
# ax.plot(displacement_vector)
@@ -29,39 +33,48 @@ def test_generate_displacement_vector():
2933
unit_locations[:, 1] = np.linspace(-50, 50, 10)
3034

3135
# one motion Y only
32-
displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
33-
generate_displacement_vector(duration, unit_locations)
34-
)
36+
(
37+
unit_displacements,
38+
displacement_vectors,
39+
displacement_unit_factor,
40+
displacement_sampling_frequency,
41+
displacements_steps,
42+
) = generate_displacement_vector(duration, unit_locations)
43+
assert unit_locations.shape[0] == unit_displacements.shape[1]
3544
assert unit_locations.shape[0] == displacement_unit_factor.shape[0]
3645
assert displacement_vectors.shape[2] == displacement_unit_factor.shape[1]
3746
assert displacement_vectors.shape[2] == 1
3847

3948
# two motion X and Y
40-
displacement_vectors, displacement_unit_factor, displacement_sampling_frequency, displacements_steps = (
41-
generate_displacement_vector(
42-
duration,
43-
unit_locations,
44-
drift_start_um=[-5, 20.0],
45-
drift_stop_um=[5, -20.0],
46-
motion_list=[
47-
dict(
48-
drift_mode="zigzag",
49-
amplitude_factor=1.0,
50-
non_rigid_gradient=0.4,
51-
t_start_drift=60.0,
52-
t_end_drift=None,
53-
period_s=200,
54-
),
55-
dict(
56-
drift_mode="bump",
57-
amplitude_factor=0.3,
58-
non_rigid_gradient=0.4,
59-
t_start_drift=60.0,
60-
t_end_drift=None,
61-
bump_interval_s=(30, 90.0),
62-
),
63-
],
64-
)
49+
(
50+
unit_displacements,
51+
displacement_vectors,
52+
displacement_unit_factor,
53+
displacement_sampling_frequency,
54+
displacements_steps,
55+
) = generate_displacement_vector(
56+
duration,
57+
unit_locations,
58+
drift_start_um=[-5, 20.0],
59+
drift_stop_um=[5, -20.0],
60+
motion_list=[
61+
dict(
62+
drift_mode="zigzag",
63+
amplitude_factor=0.7,
64+
non_rigid_gradient=0.4,
65+
t_start_drift=60.0,
66+
t_end_drift=None,
67+
period_s=200,
68+
),
69+
dict(
70+
drift_mode="bump",
71+
amplitude_factor=0.3,
72+
non_rigid_gradient=0.4,
73+
t_start_drift=60.0,
74+
t_end_drift=None,
75+
bump_interval_s=(30, 90.0),
76+
),
77+
],
6578
)
6679
assert unit_locations.shape[0] == displacement_unit_factor.shape[0]
6780
assert displacement_vectors.shape[2] == displacement_unit_factor.shape[1]
@@ -91,6 +104,6 @@ def test_generate_drifting_recording():
91104

92105
if __name__ == "__main__":
93106
# test_make_one_displacement_vector()
94-
# test_generate_displacement_vector()
107+
test_generate_displacement_vector()
95108
# test_generate_noise()
96-
test_generate_drifting_recording()
109+
# test_generate_drifting_recording()

0 commit comments

Comments
 (0)