Skip to content

Commit 7d18482

Browse files
fixed bug that accidentally filtered out ignore trials
1 parent 1986012 commit 7d18482

1 file changed

Lines changed: 9 additions & 15 deletions

File tree

src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
Tools for computing trial by trial metrics
3-
df_trials = compute_trial_metrics(nwb)
4-
df_trials = compute_bias(nwb)
2+
Tools for computing trial by trial metrics
3+
df_trials = compute_trial_metrics(nwb)
4+
df_trials = compute_bias(nwb)
55
66
"""
77

@@ -239,13 +239,7 @@ def add_intertrial_licking(df_trials, df_licks):
239239

240240

241241
def get_average_signal_window_multi(
242-
nwbs,
243-
alignment_event,
244-
offsets,
245-
channel,
246-
data_column='data_z',
247-
censor=True,
248-
output_col=None
242+
nwbs, alignment_event, offsets, channel, data_column="data_z", censor=True, output_col=None
249243
):
250244
"""
251245
Wrapper for get_average_signal_window to process a
@@ -280,7 +274,7 @@ def get_average_signal_window_multi(
280274
channel=channel,
281275
data_column=data_column,
282276
censor=censor,
283-
output_col=output_col
277+
output_col=output_col,
284278
)
285279
nwb.df_trials = df_trials
286280
return nwbs
@@ -291,7 +285,7 @@ def get_average_signal_window(
291285
alignment_event,
292286
offsets,
293287
channel,
294-
data_column='data_z',
288+
data_column="data_z",
295289
censor=True,
296290
output_col=None,
297291
):
@@ -331,7 +325,7 @@ def get_average_signal_window(
331325
"""
332326

333327
# Check alignment_event ends with 'in_session'
334-
if not alignment_event.endswith('in_session'):
328+
if not alignment_event.endswith("in_session"):
335329
raise ValueError(f"alignment_event '{alignment_event}' must end with 'in_session'.")
336330

337331
if not hasattr(nwb, "df_trials"):
@@ -380,10 +374,10 @@ def get_average_signal_window(
380374
)
381375

382376
avg_activity = etr.groupby("event_number").mean()
383-
avg_activity['trial'] = df_trials.trial.values
377+
avg_activity["trial"] = df_trials.trial.values
384378
avg_activity = avg_activity.rename(columns={data_column: output_col})
385379

386380
# Merge on 'trial'
387-
df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')
381+
df_trials = nwb.df_trials.merge(avg_activity[["trial", output_col]], on="trial", how="left")
388382

389383
return df_trials

0 commit comments

Comments
 (0)