Skip to content

Commit 73f94a9

Browse files
Merge pull request #69 from AllenNeuralDynamics/avg-signal-window
added get_average_signal_window.
2 parents 0c1686c + f5d31c4 commit 73f94a9

1 file changed

Lines changed: 155 additions & 0 deletions

File tree

src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import aind_dynamic_foraging_data_utils.nwb_utils as nu
99
import aind_dynamic_foraging_models.logistic_regression.model as model
10+
from aind_dynamic_foraging_data_utils import alignment as an
1011
import numpy as np
1112
import pandas as pd
13+
import warnings
1214

1315
import aind_dynamic_foraging_basic_analysis.licks.annotation as a
1416

@@ -234,3 +236,156 @@ def add_intertrial_licking(df_trials, df_licks):
234236
df_trials["intertrial_choice"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
235237
)
236238
return df_trials
239+
240+
241+
def get_average_signal_window_multi(
242+
nwbs,
243+
alignment_event,
244+
offsets,
245+
channel,
246+
data_column='data_z',
247+
censor=True,
248+
output_col=None
249+
):
250+
"""
251+
Wrapper for get_average_signal_window to process a
252+
list of nwb objects and concatenate the results.
253+
254+
Parameters
255+
----------
256+
nwbs : list
257+
List of nwb-like objects (each with .df_trials and .df_fip).
258+
alignment_event : str
259+
The event column in df_trials to align to.
260+
offsets : list or tuple of float
261+
[start, end] offsets (in seconds) relative to alignment_event.
262+
channel : str
263+
The value in df_fip['event'] to filter for.
264+
data_col : str
265+
Column in df_fip to extract (default 'data_z').
266+
censor, censor important timepoints before and after aligned timepoints
267+
output_col : str or None
268+
Name for the new column. If None, will be generated automatically.
269+
270+
Returns
271+
-------
272+
pd.DataFrame
273+
Concatenated DataFrame of all trials with the new signal window column.
274+
"""
275+
all_trials_avg_signal = []
276+
for nwb in nwbs:
277+
df_trials = get_average_signal_window(
278+
nwb,
279+
alignment_event=alignment_event,
280+
offsets=offsets,
281+
channel=channel,
282+
data_column=data_column,
283+
censor=censor,
284+
output_col=output_col
285+
)
286+
cols_needed = ['trial', 'ses_idx', df_trials.columns[-1]]
287+
all_trials_avg_signal.append(df_trials[cols_needed])
288+
return pd.concat(all_trials_avg_signal, ignore_index=True)
289+
290+
291+
def get_average_signal_window(
292+
nwb,
293+
alignment_event,
294+
offsets,
295+
channel,
296+
data_column='data_z',
297+
censor=True,
298+
output_col=None,
299+
):
300+
"""
301+
Returns a Series with the mean signal in a window around an alignment event,
302+
for each trial, for each session and a specific signal (event).
303+
304+
Parameters
305+
----------
306+
nwb : nwb object (or nwb-like object)
307+
nwb object with df_fip and df_trials attributes
308+
alignment_event : str
309+
The event column in df_trials to align to. must be given in_session, not in_trial
310+
offsets: list or tuple of float
311+
[start, end] offsets (in seconds) relative to alignment_event.
312+
channel : str
313+
The value in df_fip['event'] to filter for.
314+
data_column : str
315+
Column in df_fip to extract (default 'data_z').
316+
censor, censor important timepoints before and after aligned timepoints
317+
output_col : str or None
318+
Name for the new column. If None, will be generated as
319+
'<data_col>_<channel>_<start>_<end>_<alignment_event>'.
320+
321+
322+
Returns
323+
-------
324+
df_trial: pd.DataFrame
325+
DataFrame with a new column containing the mean signal
326+
in the specified window for each trial.
327+
328+
EXAMPLE
329+
*******************
330+
df_trials = get_average_signal_window(nwb, alignment_event='choice_time_in_session',
331+
offsets=[0.33,1],channel='G_0_dff-bright_mc-iso-IRLS',
332+
data_column='data_z_norm')
333+
"""
334+
335+
# Check alignment_event ends with 'in_session'
336+
if not alignment_event.endswith('in_session'):
337+
raise ValueError(f"alignment_event '{alignment_event}' must end with 'in_session'.")
338+
339+
if not hasattr(nwb, "df_trials"):
340+
raise ValueError("You need to compute df_trials: nwb_utils.create_trials_df(nwb)")
341+
342+
if not hasattr(nwb, "df_fip"):
343+
raise ValueError("You need to compute df_fip: nwb_utils.create_fib_df(nwb)")
344+
345+
# Check alignment_event is in df_trials columns
346+
if alignment_event not in nwb.df_trials.columns:
347+
raise ValueError(f"alignment_event '{alignment_event}' not found in df_trials columns.")
348+
349+
if channel not in nwb.df_fip.event.unique():
350+
warnings.warn(f"{channel} channel not found in df_fip. Returning original df_trials.")
351+
return nwb.df_trials
352+
353+
if data_column not in nwb.df_fip.columns:
354+
raise ValueError(f"data column '{data_column}' not found in df_trials columns.")
355+
356+
# Get output column name
357+
if output_col is None:
358+
output_col = (
359+
f"{data_column}_{channel}_{offsets[0]}_"
360+
f"{offsets[1]}_{alignment_event.replace('_in_session','')}"
361+
)
362+
363+
# copy df_trials, drops na values, sort trial by alignment event
364+
# sorting needed because censor in event_triggered_response sorts
365+
# this allows the trials to be matched with event_times
366+
df_trials = nwb.df_trials.dropna(subset=alignment_event, inplace=False)
367+
df_trials = df_trials.sort_values(by=alignment_event)
368+
369+
data = nwb.df_fip.query("event == @channel")
370+
align_timepoints = df_trials[alignment_event].values
371+
372+
etr = an.event_triggered_response(
373+
data,
374+
"timestamps",
375+
data_column,
376+
align_timepoints,
377+
t_start=offsets[0],
378+
t_end=offsets[1],
379+
output_sampling_rate=40,
380+
censor=censor,
381+
censor_times=None,
382+
)
383+
384+
avg_activity = etr.groupby("event_number").mean()
385+
avg_activity['trial'] = df_trials.trial.values
386+
avg_activity = avg_activity.rename(columns={data_column: output_col})
387+
388+
# Merge on 'trial'
389+
df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')
390+
391+
return df_trials

0 commit comments

Comments
 (0)