diff --git a/src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py b/src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py index 887c8db..e92c79d 100644 --- a/src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py +++ b/src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py @@ -3,6 +3,7 @@ """ import matplotlib.pyplot as plt +import pandas as pd import numpy as np from aind_dynamic_foraging_data_utils import alignment as an @@ -23,7 +24,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 ): """ Compare the same FIP channel aligned to multiple event types - nwb, nwb object for the session + nwb, nwb object for the session, or list of nwbs alignments, either a list of event types in df_events, or a dictionary whose keys are event types and values are a list of timepoints channel, (str) the name of the FIP channel @@ -38,40 +39,60 @@ def plot_fip_psth_compare_alignments( # NOQA C901 plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time', 'right_reward_delivery_time'],'G_1_preprocessed') """ - if not hasattr(nwb, "df_fip"): - print("You need to compute the df_fip first") - print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") - nwb.df_fip = nu.create_fib_df(nwb, tidy=True) - if not hasattr(nwb, "df_events"): - print("You need to compute the df_events first") - print("run `nwb.df_events = create_events_df(nwb)`") - nwb.df_events = nu.create_events_df(nwb) - - if channel not in nwb.df_fip["event"].values: - print("channel {} not in df_fip".format(channel)) - - if isinstance(alignments, list): - align_dict = {} - for a in alignments: - if a not in nwb.df_events["event"].values: - print("{} not found in the events table".format(a)) - return - else: - align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values - elif isinstance(alignments, dict): - align_dict = alignments - else: - print( - "alignments must be either a list of events in nwb.df_events, " - + "or a dictionary where each key is an event type, " - + "and the value 
is a list of timepoints" - ) - return + # Check if nwb is a list, otherwise put it in a list to check + nwb_to_check = nwb if isinstance(nwb, list) else [nwb] + align_dict = {} + for nwb_i in nwb_to_check: + if not hasattr(nwb_i, "df_fip"): + print("You need to compute the df_fip first") + print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") + nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True) + if not hasattr(nwb_i, "df_events"): + print("You need to compute the df_events first") + print("run `nwb.df_events = create_events_df(nwb)`") + nwb_i.df_events = nu.create_events_df(nwb_i) + + if channel not in nwb_i.df_fip["event"].values: + print("channel {} not in df_fip".format(channel)) + + if isinstance(alignments, list): + for a in alignments: + if a not in nwb_i.df_events["event"].values: + print("{} not found in the events table".format(a)) + return + else: + align_vals = nwb_i.df_events.query("event == @a")["timestamps"].values + align_list = align_dict.get(a, []) + if len(nwb_to_check) > 1: + align_list.append(align_vals) + else: + align_list = align_vals + align_dict[a] = align_list + + elif isinstance(alignments, dict): + align_dict = alignments + else: + print( + "alignments must be either a list of events in nwb.df_events, " + + "or a dictionary where each key is an event type, " + + "and the value is a list of timepoints" + ) + return censor_times = [] - for key in align_dict: - censor_times.append(align_dict[key]) - censor_times = np.sort(np.concatenate(censor_times)) + if isinstance(nwb, list): + # For multiple NWBs, create a list of sorted, concatenated censor times for each NWB + for i in range(len(nwb)): + per_nwb_times = [] + for key in align_dict: + per_nwb_times.append(align_dict[key][i]) + per_nwb_times = np.sort(np.concatenate(per_nwb_times)) + censor_times.append(per_nwb_times) + else: + # For a single NWB, concatenate and sort all alignments + for key in align_dict: + censor_times.append(align_dict[key]) + censor_times = 
np.sort(np.concatenate(censor_times)) align_label = "Time (s)" if fig is None and ax is None: @@ -80,9 +101,18 @@ def plot_fip_psth_compare_alignments( # NOQA C901 colors = {**FIP_COLORS, **extra_colors} for alignment in align_dict: - etr = fip_psth_inner_compute( - nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column - ) + if isinstance(nwb, list): + # Compute etr for every NWB object in the list and average + etr = fip_psth_multiple_nwb_inner_compute( + nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column + ) + session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) + + else: + etr = fip_psth_inner_compute( + nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column + ) + session_id_title = nwb.session_id fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column) plt.legend() @@ -93,7 +123,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 ax.set_xlim(tw) ax.axvline(0, color="k", alpha=0.2) ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"]) - ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"]) + ax.set_title(session_id_title, fontsize=STYLE["axis_fontsize"]) plt.tight_layout() return fig, ax @@ -127,33 +157,49 @@ def plot_fip_psth_compare_channels( ******************** plot_fip_psth(nwb, 'goCue_start_time') """ - if not hasattr(nwb, "df_fip"): - print("You need to compute the df_fip first") - print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") - nwb.df_fip = nu.create_fib_df(nwb, tidy=True) - if not hasattr(nwb, "df_events"): - print("You need to compute the df_events first") - print("run `nwb.df_events = create_events_df(nwb)`") - nwb.df_events = nu.create_events_df(nwb) - - if isinstance(align, str): - if align not in nwb.df_events["event"].values: - print("{} not found in the events table".format(align)) - return - align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values - align_label = 
"Time from {} (s)".format(align) - else: - align_timepoints = align - align_label = "Time (s)" + # Check if nwb is a list, otherwise put it in a list to check + nwb_to_check = nwb if isinstance(nwb, list) else [nwb] + + align_timepoints = [] + for nwb_i in nwb_to_check: + if not hasattr(nwb_i, "df_fip"): + print("You need to compute the df_fip first") + print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") + nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True) + if not hasattr(nwb_i, "df_events"): + print("You need to compute the df_events first") + print("run `nwb.df_events = create_events_df(nwb)`") + nwb_i.df_events = nu.create_events_df(nwb_i) + + if isinstance(align, str): + if align not in nwb_i.df_events["event"].values: + print("{} not found in the events table".format(align)) + return + + align_timepoints.append(nwb_i.df_events.query("event == @align")["timestamps"].values) + align_label = "Time from {} (s)".format(align) + else: + align_timepoints = align + align_label = "Time (s)" if fig is None and ax is None: fig, ax = plt.subplots() colors = [FIP_COLORS.get(c, "") for c in channels] for dex, c in enumerate(channels): - if c in nwb.df_fip["event"].values: - etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw, - censor, data_column=data_column) + channel_exists = all(c in nwb_i.df_fip["event"].values for nwb_i in nwb_to_check) + if channel_exists: + if isinstance(nwb, list): + # Compute etr for every NWB object in the list and average + etr = fip_psth_multiple_nwb_inner_compute(nwb, align_timepoints, c, True, tw, + censor, data_column=data_column) + session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) + + else: + etr = fip_psth_inner_compute(nwb, np.squeeze(align_timepoints), c, True, tw, + censor, data_column=data_column) + session_id_title = nwb.session_id + fip_psth_inner_plot(ax, etr, colors[dex], c, data_column) else: print("No data for channel: {}".format(c)) @@ -166,7 +212,7 @@ def plot_fip_psth_compare_channels( 
def fip_psth_multiple_nwb_inner_compute(
    nwbs_list,
    align_timepoints,
    channel,
    average,
    tw=[-1, 1],
    censor=True,
    censor_times=None,
    data_column="data",
):
    """
    helper function that computes the event triggered response for a list of nwbs

    nwbs_list, list of nwb objects to get PSTH's for. Each must already have a
        df_fip attribute
    align_timepoints, a list with one entry per nwb, each entry being an iterable
        of the timepoints to compute the ETR aligned to for that nwb
    channel, what channel in the df_fip dataframe to use
    average (bool), whether to return the across-session average, or all
        individual traces
    tw, time window before and after each event (only read, never modified)
    censor, censor important timepoints before and after aligned timepoints
    censor_times, timepoints to censor; a list with the same length as nwbs_list
        (one entry per nwb). If None, no censor times are passed for any nwb
    data_column (string), name of data column in nwb.df_fip

    returns
        if average is True, a dataframe indexed by time with columns
        data_column (mean over sessions of the per-session mean trace) and
        "sem" (standard error over sessions; NaN when only one session is given)
        if average is False, the concatenated per-event responses of every
        session with an added "ses_idx" column

    raises ValueError if nwbs_list is empty, if align_timepoints or censor_times
        do not have one entry per nwb, or if channel is missing from a session
    """
    n_sessions = len(nwbs_list)
    if n_sessions == 0:
        raise ValueError("nwbs_list must contain at least one NWB object")
    if censor_times is None:
        censor_times = [None] * n_sessions

    # Explicit raises rather than assert: asserts are removed under `python -O`,
    # and a silent mismatch would pair sessions with the wrong timepoints.
    if n_sessions != len(align_timepoints):
        raise ValueError("Number of NWBs and align timepoints must match")
    if n_sessions != len(censor_times):
        raise ValueError("Number of NWBs and censor times must match")

    session_etrs = []
    for nwb, timepoints, session_censor_times in zip(nwbs_list, align_timepoints, censor_times):
        channel_data = nwb.df_fip.query("event == @channel")
        if channel_data.empty:
            # Fail with the real cause instead of an IndexError further down
            raise ValueError("channel {} not in df_fip".format(channel))
        etr = an.event_triggered_response(
            channel_data,
            "timestamps",
            data_column,
            timepoints,
            t_start=tw[0],
            t_end=tw[1],
            output_sampling_rate=40,
            censor=censor,
            censor_times=session_censor_times,
        )
        # Tag every row with its session so sessions can be averaged separately.
        # NOTE(review): sessions sharing the same ses_idx are pooled together by
        # the groupby below - confirm ses_idx is unique per nwb.
        etr["ses_idx"] = channel_data["ses_idx"].values[0]
        session_etrs.append(etr)

    etr_all = pd.concat(session_etrs, axis=0).reset_index(drop=True)
    if not average:
        return etr_all

    # Mean trace per session: rows are time points, columns are sessions
    mean_per_ses = etr_all.groupby(["ses_idx", "time"])[data_column].mean().unstack("ses_idx")
    # Grand mean and SEM are taken across sessions, so each session counts once
    # regardless of how many events it contributed
    result = mean_per_ses.mean(axis=1).to_frame(name=data_column)
    result["sem"] = mean_per_ses.sem(axis=1)
    return result