Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions src/aind_dynamic_foraging_basic_analysis/metrics/trial_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import aind_dynamic_foraging_data_utils.nwb_utils as nu
import aind_dynamic_foraging_models.logistic_regression.model as model
from aind_dynamic_foraging_data_utils import alignment as an
import numpy as np
import pandas as pd
import warnings

import aind_dynamic_foraging_basic_analysis.licks.annotation as a

Expand Down Expand Up @@ -234,3 +236,156 @@ def add_intertrial_licking(df_trials, df_licks):
df_trials["intertrial_choice"].rolling(WIN_DUR, min_periods=MIN_EVENTS, center=True).mean()
)
return df_trials


def get_average_signal_window_multi(
    nwbs,
    alignment_event,
    offsets,
    channel,
    data_column='data_z',
    censor=True,
    output_col=None
):
    """
    Run get_average_signal_window on a list of nwb objects and
    concatenate the per-session results.

    Parameters
    ----------
    nwbs : list
        List of nwb-like objects (each with .df_trials and .df_fip).
    alignment_event : str
        The event column in df_trials to align to.
    offsets : list or tuple of float
        [start, end] offsets (in seconds) relative to alignment_event.
    channel : str
        The value in df_fip['event'] to filter for.
    data_column : str
        Column in df_fip to extract (default 'data_z').
    censor : bool
        Forwarded unchanged to get_average_signal_window (default True).
    output_col : str or None
        Name for the new column. If None, it is generated as
        '<data_column>_<channel>_<start>_<end>_<alignment_event>' with
        '_in_session' removed from alignment_event.

    Returns
    -------
    pd.DataFrame
        One row per trial returned by get_average_signal_window, with the
        columns 'trial', 'ses_idx' and the signal column. Sessions whose
        df_fip does not contain `channel` contribute their trials with NaN
        in the signal column. An empty `nwbs` list yields an empty
        DataFrame with those three columns.
    """
    # Resolve the column name once so every session contributes the same
    # column, whether or not its channel is present.
    if output_col is None:
        signal_col = (
            f"{data_column}_{channel}_{offsets[0]}_"
            f"{offsets[1]}_{alignment_event.replace('_in_session','')}"
        )
    else:
        signal_col = output_col
    key_cols = ['trial', 'ses_idx']

    all_trials_avg_signal = []
    for nwb in nwbs:
        df_trials = get_average_signal_window(
            nwb,
            alignment_event=alignment_event,
            offsets=offsets,
            channel=channel,
            data_column=data_column,
            censor=censor,
            output_col=signal_col
        )
        if channel in nwb.df_fip.event.unique():
            # The averaged signal is merged in last, so it is the final column.
            # Rename defensively in case the merge had to suffix a clashing name.
            last_col = df_trials.columns[-1]
            session = df_trials[key_cols + [last_col]].rename(columns={last_col: signal_col})
        else:
            # get_average_signal_window handed back the untouched df_trials;
            # its last column is unrelated to the signal, so fill NaN instead.
            session = df_trials[key_cols].assign(**{signal_col: float('nan')})
        all_trials_avg_signal.append(session)

    if not all_trials_avg_signal:
        # pd.concat raises on an empty list; return an empty, well-formed frame.
        return pd.DataFrame(columns=key_cols + [signal_col])
    return pd.concat(all_trials_avg_signal, ignore_index=True)


def get_average_signal_window(
    nwb,
    alignment_event,
    offsets,
    channel,
    data_column='data_z',
    censor=True,
    output_col=None,
):
    """
    Returns df_trials with an added column holding the mean signal in a
    window around an alignment event, for each trial of one session and
    one specific signal (channel).

    Parameters
    ----------
    nwb : nwb object (or nwb-like object)
        nwb object with df_fip and df_trials attributes
    alignment_event : str
        The event column in df_trials to align to. Must be given in_session,
        not in_trial (the name has to end with 'in_session').
    offsets : list or tuple of float
        [start, end] offsets (in seconds) relative to alignment_event.
    channel : str
        The value in df_fip['event'] to filter for.
    data_column : str
        Column in df_fip to extract (default 'data_z').
    censor : bool
        Forwarded to event_triggered_response: censor important timepoints
        before and after aligned timepoints (default True).
    output_col : str or None
        Name for the new column. If None, will be generated as
        '<data_column>_<channel>_<start>_<end>_<alignment_event>' with
        '_in_session' removed from alignment_event. If df_trials already has
        a column with this name it is replaced, so repeated calls with the
        same output_col update the values.

    Returns
    -------
    df_trials : pd.DataFrame
        Copy of nwb.df_trials restricted to trials with a non-null
        alignment_event, sorted by alignment_event, with the new column
        merged in on 'trial'. If `channel` is not present in df_fip, a
        warning is issued and nwb.df_trials itself is returned unchanged.

    Raises
    ------
    ValueError
        If alignment_event does not end with 'in_session', if nwb lacks
        df_trials or df_fip, if alignment_event is not a df_trials column,
        or if data_column is not a df_fip column.

    EXAMPLE
    *******************
    df_trials = get_average_signal_window(nwb, alignment_event='choice_time_in_session',
                                offsets=[0.33,1],channel='G_0_dff-bright_mc-iso-IRLS',
                                data_column='data_z_norm')
    """
    # Check alignment_event ends with 'in_session'
    if not alignment_event.endswith('in_session'):
        raise ValueError(f"alignment_event '{alignment_event}' must end with 'in_session'.")

    if not hasattr(nwb, "df_trials"):
        raise ValueError("You need to compute df_trials: nwb_utils.create_trials_df(nwb)")

    if not hasattr(nwb, "df_fip"):
        raise ValueError("You need to compute df_fip: nwb_utils.create_fib_df(nwb)")

    # Check alignment_event is in df_trials columns
    if alignment_event not in nwb.df_trials.columns:
        raise ValueError(f"alignment_event '{alignment_event}' not found in df_trials columns.")

    # A missing channel is a warning rather than an error so that multi-session
    # runs can continue past sessions that do not have this channel.
    if channel not in nwb.df_fip.event.unique():
        warnings.warn(
            f"{channel} channel not found in df_fip. Returning original df_trials.",
            stacklevel=2,
        )
        return nwb.df_trials

    # data_column is looked up in df_fip, so the message names df_fip
    if data_column not in nwb.df_fip.columns:
        raise ValueError(f"data column '{data_column}' not found in df_fip columns.")

    # Get output column name
    if output_col is None:
        output_col = (
            f"{data_column}_{channel}_{offsets[0]}_"
            f"{offsets[1]}_{alignment_event.replace('_in_session','')}"
        )

    # copy df_trials, drops na values, sort trial by alignment event
    # sorting needed because censor in event_triggered_response sorts
    # this allows the trials to be matched with event_times
    df_trials = nwb.df_trials.dropna(subset=[alignment_event], inplace=False)
    df_trials = df_trials.sort_values(by=alignment_event)

    data = nwb.df_fip.query("event == @channel")
    align_timepoints = df_trials[alignment_event].values

    etr = an.event_triggered_response(
        data,
        "timestamps",
        data_column,
        align_timepoints,
        t_start=offsets[0],
        t_end=offsets[1],
        output_sampling_rate=40,
        censor=censor,
        censor_times=None,
    )

    # NOTE(review): assumes event_triggered_response yields exactly one
    # event_number group per alignment time, in the sorted order used above;
    # a different group count would make the assignment below fail - confirm.
    avg_activity = etr.groupby("event_number").mean()
    avg_activity['trial'] = df_trials.trial.values
    avg_activity = avg_activity.rename(columns={data_column: output_col})

    # An existing output_col is deliberately overwritten. Drop the stale copy
    # first, otherwise merge would emit '<output_col>_x' / '<output_col>_y'.
    df_trials = df_trials.drop(columns=[output_col], errors='ignore')

    # Merge on 'trial'
    df_trials = df_trials.merge(avg_activity[['trial', output_col]], on='trial', how='left')

    return df_trials