Skip to content
226 changes: 167 additions & 59 deletions src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from aind_dynamic_foraging_data_utils import alignment as an
Expand All @@ -23,7 +24,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
):
"""
Compare the same FIP channel aligned to multiple event types
nwb, nwb object for the session
nwb, nwb object for the session, or list of nwbs
alignments, either a list of event types in df_events, or a dictionary
whose keys are event types and values are a list of timepoints
channel, (str) the name of the FIP channel
Expand All @@ -38,40 +39,60 @@ def plot_fip_psth_compare_alignments( # NOQA C901
plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
'right_reward_delivery_time'],'G_1_preprocessed')
"""
if not hasattr(nwb, "df_fip"):
print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
if not hasattr(nwb, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb.df_events = nu.create_events_df(nwb)

if channel not in nwb.df_fip["event"].values:
print("channel {} not in df_fip".format(channel))

if isinstance(alignments, list):
align_dict = {}
for a in alignments:
if a not in nwb.df_events["event"].values:
print("{} not found in the events table".format(a))
return
else:
align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values
elif isinstance(alignments, dict):
align_dict = alignments
else:
print(
"alignments must be either a list of events in nwb.df_events, "
+ "or a dictionary where each key is an event type, "
+ "and the value is a list of timepoints"
)
return
# Check if nwb is a list, otherwise put it in a list to check
nwb_to_check = nwb if isinstance(nwb, list) else [nwb]
align_dict = {}
for nwb_i in nwb_to_check:
if not hasattr(nwb_i, "df_fip"):
print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True)
if not hasattr(nwb_i, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb_i.df_events = nu.create_events_df(nwb_i)

if channel not in nwb_i.df_fip["event"].values:
print("channel {} not in df_fip".format(channel))

if isinstance(alignments, list):
for a in alignments:
if a not in nwb_i.df_events["event"].values:
print("{} not found in the events table".format(a))
return
else:
align_vals = nwb_i.df_events.query("event == @a")["timestamps"].values
align_list = align_dict.get(a, [])

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why save things to align_vals, then use .get()? You already checked above if a is in df_events

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid repeating it in both branches of the if/else logic below.

.get is checking align_dict, not checking nwb_i.df_events.

if len(nwb_to_check) > 1:
align_list.append(align_vals)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are two cases:

A single nwb is passed into plot_fip_psth_compare_alignments, in which case align_dict should map each string key to a list of the align_vals.

Multiple nwbs are passed into plot_fip_psth_compare_alignments, in which case align_dict should map each string key to a list of lists (one list of align_vals per NWB).

If I don't separate out the two cases when looking at alignments given as a column, the align_dict for a single NWB will have a string key and a nested list (list of a single list). This will cause a problem in line 95, because np.concatenate will fail.

else:
align_list = align_vals
align_dict[a] = align_list

elif isinstance(alignments, dict):
align_dict = alignments
else:
print(
"alignments must be either a list of events in nwb.df_events, "
+ "or a dictionary where each key is an event type, "
+ "and the value is a list of timepoints"
)
return

censor_times = []
for key in align_dict:
censor_times.append(align_dict[key])
censor_times = np.sort(np.concatenate(censor_times))
if isinstance(nwb, list):
# For multiple NWBs, create a list of sorted, concatenated censor times for each NWB
for i in range(len(nwb)):
per_nwb_times = []
for key in align_dict:
per_nwb_times.append(align_dict[key][i])
per_nwb_times = np.sort(np.concatenate(per_nwb_times))
censor_times.append(per_nwb_times)
else:
# For a single NWB, concatenate and sort all alignments
for key in align_dict:
censor_times.append(align_dict[key])
censor_times = np.sort(np.concatenate(censor_times))

align_label = "Time (s)"
if fig is None and ax is None:
Expand All @@ -80,9 +101,18 @@ def plot_fip_psth_compare_alignments( # NOQA C901
colors = {**FIP_COLORS, **extra_colors}

for alignment in align_dict:
etr = fip_psth_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
if isinstance(nwb, list):
# Compute etr for every NWB object in the list and average
etr = fip_psth_multiple_nwb_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb])

else:
etr = fip_psth_inner_compute(
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
)
session_id_title = nwb.session_id
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column)

plt.legend()
Expand All @@ -93,7 +123,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"])
ax.set_title(session_id_title, fontsize=STYLE["axis_fontsize"])
plt.tight_layout()
return fig, ax

Expand Down Expand Up @@ -127,33 +157,49 @@ def plot_fip_psth_compare_channels(
********************
plot_fip_psth(nwb, 'goCue_start_time')
"""
if not hasattr(nwb, "df_fip"):
print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
if not hasattr(nwb, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb.df_events = nu.create_events_df(nwb)

if isinstance(align, str):
if align not in nwb.df_events["event"].values:
print("{} not found in the events table".format(align))
return
align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values
align_label = "Time from {} (s)".format(align)
else:
align_timepoints = align
align_label = "Time (s)"
# Check if nwb is a list, otherwise put it in a list to check
nwb_to_check = nwb if isinstance(nwb, list) else [nwb]

align_timepoints = []
for nwb_i in nwb_to_check:
if not hasattr(nwb_i, "df_fip"):
print("You need to compute the df_fip first")
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True)
if not hasattr(nwb_i, "df_events"):
print("You need to compute the df_events first")
print("run `nwb.df_events = create_events_df(nwb)`")
nwb_i.df_events = nu.create_events_df(nwb_i)

if isinstance(align, str):
if align not in nwb_i.df_events["event"].values:
print("{} not found in the events table".format(align))
return

align_timepoints.append(nwb_i.df_events.query("event == @align")["timestamps"].values)
align_label = "Time from {} (s)".format(align)
else:
align_timepoints = align
align_label = "Time (s)"

if fig is None and ax is None:
fig, ax = plt.subplots()

colors = [FIP_COLORS.get(c, "") for c in channels]
for dex, c in enumerate(channels):
if c in nwb.df_fip["event"].values:
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw,
censor, data_column=data_column)
channel_exists = all(c in nwb_i.df_fip["event"].values for nwb_i in nwb_to_check)
if channel_exists:
if isinstance(nwb, list):
# Compute etr for every NWB object in the list and average
etr = fip_psth_multiple_nwb_inner_compute(nwb, align_timepoints, c, True, tw,
censor, data_column=data_column)
session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb])

else:
etr = fip_psth_inner_compute(nwb, np.squeeze(align_timepoints), c, True, tw,
censor, data_column=data_column)
session_id_title = nwb.session_id

fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
else:
print("No data for channel: {}".format(c))
Expand All @@ -166,7 +212,7 @@ def plot_fip_psth_compare_channels(
ax.set_xlim(tw)
ax.axvline(0, color="k", alpha=0.2)
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
ax.set_title(nwb.session_id)
ax.set_title(session_id_title)
plt.tight_layout()
return fig, ax

Expand All @@ -188,6 +234,68 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column):
ax.plot(etr.index, etr[data_column], color=color, label=label)


def fip_psth_multiple_nwb_inner_compute(
    nwbs_list,
    align_timepoints,
    channel,
    average,
    tw=[-1, 1],
    censor=True,
    censor_times=None,
    data_column="data",
    output_sampling_rate=40,
):
    """
    helper function that computes the event triggered response for a list of nwbs
    nwbs_list, list of nwb objects to get PSTH's for. Each one must already
        have a df_fip attribute with "event", "timestamps", "ses_idx" and
        data_column columns
    align_timepoints, a list with one entry per nwb, each entry an iterable of
        the timepoints to compute the ETR aligned to for that nwb
    channel, what channel in the df_fip dataframe to use
    average(bool), whether to return the average, or all individual traces
    tw, time window before and after each event (only read, never modified)
    censor, censor important timepoints before and after aligned timepoints
    censor_times, timepoints to censor, a list with the same length as
        nwbs_list (one entry per nwb). If None, no censor times are passed
        for any nwb
    data_column (string), name of data column in nwb.df_fip
    output_sampling_rate, sampling rate passed to event_triggered_response.
        Defaults to 40, the value that was previously hard-coded

    returns
        if average is False, the per-nwb event triggered responses stacked
        into one dataframe (index reset) with an added 'ses_idx' column
        if average is True, a dataframe indexed by time with data_column
        holding the mean across sessions of each session's mean trace, and
        'sem' holding the standard error of the mean across sessions

    raises
        ValueError if nwbs_list is empty, if align_timepoints or censor_times
        do not have one entry per nwb, or if channel has no rows in an nwb's
        df_fip
    """

    n_sessions = len(nwbs_list)
    if n_sessions == 0:
        # pd.concat would otherwise fail later with a less helpful message
        raise ValueError("nwbs_list must contain at least one NWB")

    if censor_times is None:
        censor_times = [None] * n_sessions

    # Explicit raises instead of assert, so the checks survive `python -O`
    if len(align_timepoints) != n_sessions:
        raise ValueError("Number of NWBs and align timepoints must match")
    if len(censor_times) != n_sessions:
        raise ValueError("Number of NWBs and censor times must match")

    etr_list = []
    for nwb, session_timepoints, session_censor_times in zip(
        nwbs_list, align_timepoints, censor_times
    ):
        data = nwb.df_fip.query("event == @channel")
        if data.empty:
            # Without this guard, reading the first ses_idx below would fail
            # with an IndexError that does not mention the channel
            raise ValueError("channel {} not in df_fip".format(channel))

        etr = an.event_triggered_response(
            data,
            "timestamps",
            data_column,
            session_timepoints,
            t_start=tw[0],
            t_end=tw[1],
            output_sampling_rate=output_sampling_rate,
            censor=censor,
            censor_times=session_censor_times,
        )
        # Tag every row with its session so sessions can be told apart
        # after concatenation
        etr["ses_idx"] = data["ses_idx"].values[0]
        etr_list.append(etr)

    etr_all = pd.concat(etr_list, axis=0).reset_index(drop=True)
    if not average:
        return etr_all

    # Average within each session first so that every session contributes
    # equally to the grand mean, regardless of how many events it has.
    # Rows are time points, columns are sessions.
    mean_per_ses = (
        etr_all.groupby(["ses_idx", "time"])[data_column].mean().unstack("ses_idx")
    )
    # Grand mean across sessions for each time point
    result = mean_per_ses.mean(axis=1).to_frame(name=data_column)
    # SEM across sessions (not across individual events) for each time point
    result["sem"] = mean_per_ses.sem(axis=1)
    return result


def fip_psth_inner_compute(
nwb,
align_timepoints,
Expand Down