1818 generate_unit_locations ,
1919 generate_sorting ,
2020 generate_templates ,
21+ synthesize_amplitude_factor ,
2122 _ensure_unit_params ,
2223 _ensure_seed ,
2324)
@@ -366,6 +367,8 @@ def generate_drifting_recording(
366367 generate_sorting_kwargs = dict (firing_rates = (2.0 , 8.0 ), refractory_period_ms = 4.0 ),
367368 noise = None ,
368369 generate_noise_kwargs = dict (noise_levels = (6.0 , 8.0 ), spatial_decay = 25.0 ),
370+ amplitude_std : float | None = None ,
371+ amplitude_factor : np .ndarray | None = None ,
369372 extra_outputs = False ,
370373 seed = None ,
371374):
@@ -405,6 +408,11 @@ def generate_drifting_recording(
405408 Noise generator used to generate background noise
406409 generate_noise_kwargs : dict
407410 Parameters given to generate_noise() if no noise is None
411+ amplitude_std : float, default: 0.05
412+ The standard deviation of the modulation to apply to the spikes when injecting them
413+ into the recording.
414+ amplitude_factor: np.ndarray, optional
415+ Optional fixed per-spike amplitude modulation
408416 extra_outputs : bool, default False
409417 Return optionaly a dict with more variables.
410418 seed : None ot int
@@ -559,6 +567,13 @@ def generate_drifting_recording(
559567 assert noise .probe .get_contact_count () == probe .get_contact_count (), "Noise num channels mismatch"
560568 assert noise .get_total_duration () == duration , "Noise duration should be the same as the recording duration"
561569
570+ amplitude_factor = synthesize_amplitude_factor (
571+ num_spikes = sorting .count_total_num_spikes (),
572+ amplitude_factor = amplitude_factor ,
573+ amplitude_std = amplitude_std ,
574+ seed = seed ,
575+ )
576+
562577 static_recording = InjectDriftingTemplatesRecording (
563578 sorting = sorting ,
564579 parent_recording = noise ,
@@ -567,7 +582,7 @@ def generate_drifting_recording(
567582 displacement_sampling_frequency = displacement_sampling_frequency ,
568583 displacement_unit_factor = np .zeros_like (displacement_unit_factor ),
569584 num_samples = [int (duration * sampling_frequency )],
570- amplitude_factor = None ,
585+ amplitude_factor = amplitude_factor ,
571586 )
572587
573588 drifting_recording = InjectDriftingTemplatesRecording (
@@ -578,7 +593,7 @@ def generate_drifting_recording(
578593 displacement_sampling_frequency = displacement_sampling_frequency ,
579594 displacement_unit_factor = displacement_unit_factor ,
580595 num_samples = [int (duration * sampling_frequency )],
581- amplitude_factor = None ,
596+ amplitude_factor = amplitude_factor ,
582597 )
583598
584599 if extra_outputs :
0 commit comments