Skip to content
126 changes: 106 additions & 20 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from aind_dynamic_foraging_data_utils import alignment as an
Expand Down Expand Up @@ -38,16 +39,19 @@ def plot_fip_psth_compare_alignments( # NOQA C901
plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
'right_reward_delivery_time'],'G_1_preprocessed')
"""
if not hasattr(nwb, "df_fip"):
# Check if nwb is a list, and if so, check only the first element for attributes and channel
nwb_to_check = nwb[0] if isinstance(nwb, list) else nwb

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why check only the first NWB? Checks are fast, so we should just check all of them

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, done. I am testing it now and will send the PR back to you once that is finished.


if not hasattr(nwb_to_check, "df_fip"):
print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
if not hasattr(nwb, "df_events"):
nwb_to_check.df_fip = nu.create_fib_df(nwb_to_check, tidy=True)
if not hasattr(nwb_to_check, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb.df_events = nu.create_events_df(nwb)
nwb_to_check.df_events = nu.create_events_df(nwb_to_check)

if channel not in nwb.df_fip["event"].values:
if channel not in nwb_to_check.df_fip["event"].values:
print("channel {} not in df_fip".format(channel))

if isinstance(alignments, list):
Expand All @@ -60,6 +64,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values
elif isinstance(alignments, dict):
align_dict = alignments
align_dict_flat = alignments.copy()
else:
print(
"alignments must be either a list of events in nwb.df_events, "
Expand All @@ -70,7 +75,9 @@ def plot_fip_psth_compare_alignments( # NOQA C901

censor_times = []
for key in align_dict:
censor_times.append(align_dict[key])
if isinstance(nwb, list):
align_dict_flat[key] = np.concatenate(align_dict[key])
censor_times.append(align_dict_flat[key])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you are doing here with align_dict_flat

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels wrong, because the censor times from one NWB shouldn't impact the censor times of another NWB

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right-- fixed.

censor_times = np.sort(np.concatenate(censor_times))

align_label = "Time (s)"
Expand All @@ -80,9 +87,18 @@ def plot_fip_psth_compare_alignments( # NOQA C901
colors = {**FIP_COLORS, **extra_colors}

for alignment in align_dict:
etr = fip_psth_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
if isinstance(nwb, list):
# Compute etr for every NWB object in the list and average
etr = fip_psth_multiple_nwb_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb])

else:
etr = fip_psth_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
session_id_title = nwb.session_id
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column)

plt.legend()
Expand All @@ -93,7 +109,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"])
ax.set_title(session_id_title, fontsize=STYLE["axis_fontsize"])
plt.tight_layout()
return fig, ax

Expand Down Expand Up @@ -127,20 +143,27 @@ def plot_fip_psth_compare_channels(
********************
plot_fip_psth(nwb, 'goCue_start_time')
"""
if not hasattr(nwb, "df_fip"):
# Check if nwb is a list, and if so, check only the first element for attributes and channel
nwb_to_check = nwb[0] if isinstance(nwb, list) else nwb

if not hasattr(nwb_to_check, "df_fip"):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I would check all NWBs

print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
if not hasattr(nwb, "df_events"):
nwb_to_check.df_fip = nu.create_fib_df(nwb_to_check, tidy=True)
if not hasattr(nwb_to_check, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb.df_events = nu.create_events_df(nwb)
nwb_to_check.df_events = nu.create_events_df(nwb_to_check)

if isinstance(align, str):
if align not in nwb.df_events["event"].values:
if align not in nwb_to_check.df_events["event"].values:
print("{} not found in the events table".format(align))
return
align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values
if isinstance(nwb, list):
align_timepoints = [nwb_i.df_events.query("event == @align")["timestamps"].values
for nwb_i in nwb]
else:
align_timepoints = nwb_to_check.df_events.query("event == @align")["timestamps"].values
align_label = "Time from {} (s)".format(align)
else:
align_timepoints = align
Expand All @@ -151,9 +174,18 @@ def plot_fip_psth_compare_channels(

colors = [FIP_COLORS.get(c, "") for c in channels]
for dex, c in enumerate(channels):
if c in nwb.df_fip["event"].values:
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw,
censor, data_column=data_column)
if c in nwb_to_check.df_fip["event"].values:
if isinstance(nwb, list):
# Compute etr for every NWB object in the list and average
etr = fip_psth_multiple_nwb_inner_compute(nwb, align_timepoints, c, True, tw,
censor, data_column=data_column)
session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb])

else:
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw,
censor, data_column=data_column)
session_id_title = nwb.session_id

fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
else:
print("No data for channel: {}".format(c))
Expand All @@ -166,7 +198,7 @@ def plot_fip_psth_compare_channels(
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id)
ax.set_title(session_id_title)
plt.tight_layout()
return fig, ax

Expand All @@ -188,6 +220,60 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column):
ax.plot(etr.index, etr[data_column], color=color, label=label)


def fip_psth_multiple_nwb_inner_compute(
    nwbs_list,
    align_timepoints,
    channel,
    average,
    tw=(-1, 1),
    censor=True,
    censor_times=None,
    data_column="data",
):
    """
    helper function that computes the event triggered response across
    several sessions and, optionally, averages it across sessions

    nwbs_list, list of nwb objects, each should have a df_fip attribute
    align_timepoints, list with one entry per nwb in nwbs_list (same order);
        entry i holds the timepoints to compute the ETR aligned to for
        nwbs_list[i]
    channel, what channel in the df_fip dataframe to use
    average(bool), whether to return the across-session average, or all
        individual traces
    tw, time window before and after each event, as (start, end)
    censor, censor important timepoints before and after aligned timepoints
    censor_times, timepoints to censor
    data_column (string), name of data column in nwb.df_fip

    returns
        if average is True, a dataframe indexed by time with two columns:
        data_column (the mean over sessions of each session's mean trace)
        and "sem" (the standard error of the mean over sessions)
        if average is False, the concatenated per-session event triggered
        responses with an added "ses_idx" column

    raises ValueError if nwbs_list is empty, if align_timepoints does not
    have exactly one entry per session, or if a session has no rows for
    channel in its df_fip
    """
    if len(nwbs_list) == 0:
        raise ValueError("nwbs_list must contain at least one nwb object")
    if len(align_timepoints) != len(nwbs_list):
        # Sessions and alignment times are paired by position, so a length
        # mismatch would silently align a session to the wrong timepoints
        raise ValueError(
            "align_timepoints must have one entry per nwb: got {} entries for {} nwbs".format(
                len(align_timepoints), len(nwbs_list)
            )
        )

    etr_list = []
    for nwb, session_timepoints in zip(nwbs_list, align_timepoints):
        data = nwb.df_fip.query("event == @channel")
        if data.empty:
            # Without this check the failure would surface later as an
            # opaque IndexError when reading the session index
            raise ValueError(
                "channel {} not in df_fip for session {}".format(
                    channel, getattr(nwb, "session_id", "unknown")
                )
            )
        # NOTE(review): the same censor_times is passed for every session;
        # confirm that is intended when timestamps are session-specific
        etr = an.event_triggered_response(
            data,
            "timestamps",
            data_column,
            session_timepoints,
            t_start=tw[0],
            t_end=tw[1],
            output_sampling_rate=40,
            censor=censor,
            censor_times=censor_times,
        )
        # Tag every trace with the session it came from so that traces can
        # be grouped per session below
        etr["ses_idx"] = data["ses_idx"].values[0]
        etr_list.append(etr)

    etr_all = pd.concat(etr_list, axis=0).reset_index(drop=True)
    if not average:
        return etr_all

    # Average within each ses_idx for each time point (rows: time,
    # columns: ses_idx)
    mean_per_ses = etr_all.groupby(["ses_idx", "time"])[data_column].mean().unstack("ses_idx")
    # Grand mean: average across ses_idx for each time point, so every
    # session contributes equally regardless of its number of events
    result = mean_per_ses.mean(axis=1).to_frame(name=data_column)
    # SEM over ses_idx for each time point
    result["sem"] = mean_per_ses.sem(axis=1)
    return result


def fip_psth_inner_compute(
nwb,
align_timepoints,
Expand Down