Skip to content

Commit 1bdc40c

Browse files
addressed alex's concerns. linted. need to test.
1 parent 737ac73 commit 1bdc40c

1 file changed

Lines changed: 28 additions & 12 deletions

File tree

src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aind_dynamic_foraging_data_utils.nwb_utils as nu
99
import aind_dynamic_foraging_models.logistic_regression.model as model
10-
import aind_dynamic_foraging_basic_analysis.plot.plot_fip as pf
10+
from aind_dynamic_foraging_data_utils import alignment as an
1111
import numpy as np
1212
import pandas as pd
1313

@@ -242,7 +242,8 @@ def get_average_signal_window(
242242
alignment_event,
243243
offsets,
244244
channel,
245-
data_col='data_z',
245+
data_column='data_z',
246+
censor=True,
246247
output_col=None,
247248
):
248249
"""
@@ -255,12 +256,13 @@ def get_average_signal_window(
255256
nwb object with df_fip and df_trials attributes
256257
alignment_event : str
257258
The event column in df_trials to align to. must be given in_session, not in_trial
258-
offsets : list or tuple of float
259+
offsets: list or tuple of float
259260
[start, end] offsets (in seconds) relative to alignment_event.
260261
channel : str
261262
The value in df_fip['event'] to filter for.
262-
data_col : str
263+
data_column : str
263264
Column in df_fip to extract (default 'data_z').
265+
censor, censor important timepoints before and after aligned timepoints
264266
output_col : str or None
265267
Name for the new column. If None, will be generated as
266268
'<data_col>_<channel>_<start>_<end>_<alignment_event>'.
@@ -287,20 +289,34 @@ def get_average_signal_window(
287289
# Get output column name
288290
if output_col is None:
289291
output_col = (
290-
f"{data_col}_{channel}_{offsets[0]}_"
292+
f"{data_column}_{channel}_{offsets[0]}_"
291293
f"{offsets[1]}_{alignment_event.replace('_in_session','')}"
292294
)
293295

294-
df_trials = nwb.df_trials.copy()
295-
296-
# get event triggered response. Censor set to FALSE because event_times should match trial #
297-
etr = pf.fip_psth_inner_compute(nwb, nwb.df_trials[alignment_event].values,
298-
channel=channel, average=False, tw=offsets,
299-
censor=False, data_column=data_col)
296+
# copy df_trials, drops na values, sort trial by alignment event
297+
# sorting needed because censor in event_triggered_response sorts
298+
# this allows the trials to be matched with event_times
299+
df_trials = nwb.df_trials.dropna(subset=alignment_event, inplace=False)
300+
df_trials = df_trials.sort_values(by=alignment_event)
301+
302+
data = nwb.df_fip.query("event == @channel")
303+
align_timepoints = df_trials[alignment_event].values
304+
305+
etr = an.event_triggered_response(
306+
data,
307+
"timestamps",
308+
data_column,
309+
align_timepoints,
310+
t_start=offsets[0],
311+
t_end=offsets[1],
312+
output_sampling_rate=40,
313+
censor=censor,
314+
censor_times=None,
315+
)
300316

301317
avg_activity = etr.groupby("event_number").mean()
302318
avg_activity['trial'] = df_trials.trial.values
303-
avg_activity = avg_activity.rename(columns={data_col: output_col})
319+
avg_activity = avg_activity.rename(columns={data_column: output_col})
304320

305321
# Merge on 'trial'
306322
df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')

0 commit comments

Comments
 (0)