Skip to content

Commit 9d72b61

Browse files
committed
amplitude jitter in drifting_generator; share logic with hybrid_tools
1 parent eecaffe commit 9d72b61

3 files changed

Lines changed: 41 additions & 6 deletions

File tree

src/spikeinterface/core/generate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,22 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol
10721072
return spike_train
10731073

10741074

1075+
def synthesize_amplitude_factor(
1076+
num_spikes: int,
1077+
amplitude_factor: np.ndarray | None = None,
1078+
amplitude_std: float | None = None,
1079+
seed: np.random.Generator | int | None = None,
1080+
):
1081+
if amplitude_factor is not None:
1082+
assert amplitude_factor.shape == (num_spikes,)
1083+
return amplitude_factor
1084+
elif amplitude_std:
1085+
rng = np.random.default_rng(seed)
1086+
return rng.normal(loc=1, scale=amplitude_std, size=num_spikes)
1087+
else:
1088+
return None
1089+
1090+
10751091
from spikeinterface.core.basesorting import BaseSortingSegment, BaseSorting
10761092

10771093

src/spikeinterface/generation/drifting_generator.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
generate_unit_locations,
1919
generate_sorting,
2020
generate_templates,
21+
synthesize_amplitude_factor,
2122
_ensure_unit_params,
2223
_ensure_seed,
2324
)
@@ -366,6 +367,8 @@ def generate_drifting_recording(
366367
generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
367368
noise=None,
368369
generate_noise_kwargs=dict(noise_levels=(6.0, 8.0), spatial_decay=25.0),
370+
amplitude_std: float | None = None,
371+
amplitude_factor: np.ndarray | None = None,
369372
extra_outputs=False,
370373
seed=None,
371374
):
@@ -405,6 +408,11 @@ def generate_drifting_recording(
405408
Noise generator used to generate background noise
406409
generate_noise_kwargs : dict
407410
Parameters given to generate_noise() if no noise is None
411+
amplitude_std : float, default: 0.05
412+
The standard deviation of the modulation to apply to the spikes when injecting them
413+
into the recording.
414+
amplitude_factor: np.ndarray, optional
415+
Optional fixed per-spike amplitude modulation
408416
extra_outputs : bool, default False
409417
Return optionaly a dict with more variables.
410418
seed : None ot int
@@ -559,6 +567,13 @@ def generate_drifting_recording(
559567
assert noise.probe.get_contact_count() == probe.get_contact_count(), "Noise num channels mismatch"
560568
assert noise.get_total_duration() == duration, "Noise duration should be the same as the recording duration"
561569

570+
amplitude_factor = synthesize_amplitude_factor(
571+
num_spikes=sorting.count_total_num_spikes(),
572+
amplitude_factor=amplitude_factor,
573+
amplitude_std=amplitude_std,
574+
seed=seed,
575+
)
576+
562577
static_recording = InjectDriftingTemplatesRecording(
563578
sorting=sorting,
564579
parent_recording=noise,
@@ -567,7 +582,7 @@ def generate_drifting_recording(
567582
displacement_sampling_frequency=displacement_sampling_frequency,
568583
displacement_unit_factor=np.zeros_like(displacement_unit_factor),
569584
num_samples=[int(duration * sampling_frequency)],
570-
amplitude_factor=None,
585+
amplitude_factor=amplitude_factor,
571586
)
572587

573588
drifting_recording = InjectDriftingTemplatesRecording(
@@ -578,7 +593,7 @@ def generate_drifting_recording(
578593
displacement_sampling_frequency=displacement_sampling_frequency,
579594
displacement_unit_factor=displacement_unit_factor,
580595
num_samples=[int(duration * sampling_frequency)],
581-
amplitude_factor=None,
596+
amplitude_factor=amplitude_factor,
582597
)
583598

584599
if extra_outputs:

src/spikeinterface/generation/hybrid_tools.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
generate_sorting,
1111
InjectTemplatesRecording,
1212
_ensure_seed,
13+
synthesize_amplitude_factor,
1314
)
1415
from spikeinterface.core.template_tools import get_template_extremum_channel
1516

@@ -327,6 +328,7 @@ def generate_hybrid_recording(
327328
upsample_factor: int | None = None,
328329
upsample_vector: np.ndarray | None = None,
329330
amplitude_std: float = 0.05,
331+
amplitude_factor: np.ndarray | None = None,
330332
generate_sorting_kwargs: dict = dict(num_units=10, firing_rates=15, refractory_period_ms=4.0, seed=2205),
331333
generate_unit_locations_kwargs: dict = dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
332334
generate_templates_kwargs: dict = dict(ms_before=1.0, ms_after=3.0),
@@ -499,10 +501,12 @@ def generate_hybrid_recording(
499501
upsample_factor = templates_array.shape[3]
500502
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)
501503

502-
if amplitude_std is not None:
503-
amplitude_factor = rng.normal(loc=1, scale=amplitude_std, size=num_spikes)
504-
else:
505-
amplitude_factor = None
504+
amplitude_factor = synthesize_amplitude_factor(
505+
num_spikes,
506+
amplitude_factor=amplitude_factor,
507+
amplitude_std=amplitude_std,
508+
seed=rng,
509+
)
506510

507511
if motion is not None:
508512
assert num_segments == motion.num_segments, "recording and motion should have the same number of segments"

0 commit comments

Comments
 (0)