-
Notifications
You must be signed in to change notification settings - Fork 0
Psth multiple nwbs #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4083931
3505139
db1bd13
886c250
5a839b4
d25325d
3e5b69f
fba8197
950b690
ffde75f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| """ | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import pandas as pd | ||
| import numpy as np | ||
|
|
||
| from aind_dynamic_foraging_data_utils import alignment as an | ||
|
|
@@ -23,7 +24,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| ): | ||
| """ | ||
| Compare the same FIP channel aligned to multiple event types | ||
| nwb, nwb object for the session | ||
| nwb, nwb object for the session, or list of nwbs | ||
| alignments, either a list of event types in df_events, or a dictionary | ||
| whose keys are event types and values are a list of timepoints | ||
| channel, (str) the name of the FIP channel | ||
|
|
@@ -38,40 +39,60 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time', | ||
| 'right_reward_delivery_time'],'G_1_preprocessed') | ||
| """ | ||
| if not hasattr(nwb, "df_fip"): | ||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb.df_fip = nu.create_fib_df(nwb, tidy=True) | ||
| if not hasattr(nwb, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb.df_events = nu.create_events_df(nwb) | ||
|
|
||
| if channel not in nwb.df_fip["event"].values: | ||
| print("channel {} not in df_fip".format(channel)) | ||
|
|
||
| if isinstance(alignments, list): | ||
| align_dict = {} | ||
| for a in alignments: | ||
| if a not in nwb.df_events["event"].values: | ||
| print("{} not found in the events table".format(a)) | ||
| return | ||
| else: | ||
| align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values | ||
| elif isinstance(alignments, dict): | ||
| align_dict = alignments | ||
| else: | ||
| print( | ||
| "alignments must be either a list of events in nwb.df_events, " | ||
| + "or a dictionary where each key is an event type, " | ||
| + "and the value is a list of timepoints" | ||
| ) | ||
| return | ||
| # Check if nwb is a list, otherwise put it in a list to check | ||
| nwb_to_check = nwb if isinstance(nwb, list) else [nwb] | ||
| align_dict = {} | ||
| for nwb_i in nwb_to_check: | ||
| if not hasattr(nwb_i, "df_fip"): | ||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True) | ||
| if not hasattr(nwb_i, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb_i.df_events = nu.create_events_df(nwb_i) | ||
|
|
||
| if channel not in nwb_i.df_fip["event"].values: | ||
| print("channel {} not in df_fip".format(channel)) | ||
|
|
||
| if isinstance(alignments, list): | ||
| for a in alignments: | ||
| if a not in nwb_i.df_events["event"].values: | ||
| print("{} not found in the events table".format(a)) | ||
| return | ||
| else: | ||
| align_vals = nwb_i.df_events.query("event == @a")["timestamps"].values | ||
| align_list = align_dict.get(a, []) | ||
| if len(nwb_to_check) > 1: | ||
| align_list.append(align_vals) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whats going on here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are two cases: one single nwb coming into multiple nwb's coming into 'plot_fip_psth_compare_alignments', in which case the align_dict should have a string key and a list of lists (each list an align_val for each NWB) If I don't separate out the two cases when looking at alignments given as a column, the align_dict for a single NWB will have a string key and a nested list (list of a single list). This will cause a problem in line 95, because np.concatenate will fail. |
||
| else: | ||
| align_list = align_vals | ||
| align_dict[a] = align_list | ||
|
|
||
| elif isinstance(alignments, dict): | ||
| align_dict = alignments | ||
| else: | ||
| print( | ||
| "alignments must be either a list of events in nwb.df_events, " | ||
| + "or a dictionary where each key is an event type, " | ||
| + "and the value is a list of timepoints" | ||
| ) | ||
| return | ||
|
|
||
| censor_times = [] | ||
| for key in align_dict: | ||
| censor_times.append(align_dict[key]) | ||
| censor_times = np.sort(np.concatenate(censor_times)) | ||
| if isinstance(nwb, list): | ||
| # For multiple NWBs, create a list of sorted, concatenated censor times for each NWB | ||
| for i in range(len(nwb)): | ||
| per_nwb_times = [] | ||
| for key in align_dict: | ||
| per_nwb_times.append(align_dict[key][i]) | ||
| per_nwb_times = np.sort(np.concatenate(per_nwb_times)) | ||
| censor_times.append(per_nwb_times) | ||
| else: | ||
| # For a single NWB, concatenate and sort all alignments | ||
| for key in align_dict: | ||
| censor_times.append(align_dict[key]) | ||
| censor_times = np.sort(np.concatenate(censor_times)) | ||
|
|
||
| align_label = "Time (s)" | ||
| if fig is None and ax is None: | ||
|
|
@@ -80,9 +101,18 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| colors = {**FIP_COLORS, **extra_colors} | ||
|
|
||
| for alignment in align_dict: | ||
| etr = fip_psth_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| if isinstance(nwb, list): | ||
| # Compute etr for every NWB object in the list and average | ||
| etr = fip_psth_multiple_nwb_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) | ||
|
|
||
| else: | ||
| etr = fip_psth_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| session_id_title = nwb.session_id | ||
| fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column) | ||
|
|
||
| plt.legend() | ||
|
|
@@ -93,7 +123,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| ax.set_xlim(tw) | ||
| ax.axvline(0, color="k", alpha=0.2) | ||
| ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"]) | ||
| ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"]) | ||
| ax.set_title(session_id_title, fontsize=STYLE["axis_fontsize"]) | ||
| plt.tight_layout() | ||
| return fig, ax | ||
|
|
||
|
|
@@ -127,33 +157,49 @@ def plot_fip_psth_compare_channels( | |
| ******************** | ||
| plot_fip_psth(nwb, 'goCue_start_time') | ||
| """ | ||
| if not hasattr(nwb, "df_fip"): | ||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb.df_fip = nu.create_fib_df(nwb, tidy=True) | ||
| if not hasattr(nwb, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb.df_events = nu.create_events_df(nwb) | ||
|
|
||
| if isinstance(align, str): | ||
| if align not in nwb.df_events["event"].values: | ||
| print("{} not found in the events table".format(align)) | ||
| return | ||
| align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values | ||
| align_label = "Time from {} (s)".format(align) | ||
| else: | ||
| align_timepoints = align | ||
| align_label = "Time (s)" | ||
| # Check if nwb is a list, otherwise put it in a list to check | ||
| nwb_to_check = nwb if isinstance(nwb, list) else [nwb] | ||
|
|
||
| align_timepoints = [] | ||
| for nwb_i in nwb_to_check: | ||
| if not hasattr(nwb_i, "df_fip"): | ||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True) | ||
| if not hasattr(nwb_i, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb_i.df_events = nu.create_events_df(nwb_i) | ||
|
|
||
| if isinstance(align, str): | ||
| if align not in nwb_i.df_events["event"].values: | ||
| print("{} not found in the events table".format(align)) | ||
| return | ||
|
|
||
| align_timepoints.append(nwb_i.df_events.query("event == @align")["timestamps"].values) | ||
| align_label = "Time from {} (s)".format(align) | ||
| else: | ||
| align_timepoints = align | ||
| align_label = "Time (s)" | ||
|
|
||
| if fig is None and ax is None: | ||
| fig, ax = plt.subplots() | ||
|
|
||
| colors = [FIP_COLORS.get(c, "") for c in channels] | ||
| for dex, c in enumerate(channels): | ||
| if c in nwb.df_fip["event"].values: | ||
| etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw, | ||
| censor, data_column=data_column) | ||
| channel_exists = all(c in nwb_i.df_fip["event"].values for nwb_i in nwb_to_check) | ||
| if channel_exists: | ||
| if isinstance(nwb, list): | ||
| # Compute etr for every NWB object in the list and average | ||
| etr = fip_psth_multiple_nwb_inner_compute(nwb, align_timepoints, c, True, tw, | ||
| censor, data_column=data_column) | ||
| session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) | ||
|
|
||
| else: | ||
| etr = fip_psth_inner_compute(nwb, np.squeeze(align_timepoints), c, True, tw, | ||
| censor, data_column=data_column) | ||
| session_id_title = nwb.session_id | ||
|
|
||
| fip_psth_inner_plot(ax, etr, colors[dex], c, data_column) | ||
| else: | ||
| print("No data for channel: {}".format(c)) | ||
|
|
@@ -166,7 +212,7 @@ def plot_fip_psth_compare_channels( | |
| ax.set_xlim(tw) | ||
| ax.axvline(0, color="k", alpha=0.2) | ||
| ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"]) | ||
| ax.set_title(nwb.session_id) | ||
| ax.set_title(session_id_title) | ||
| plt.tight_layout() | ||
| return fig, ax | ||
|
|
||
|
|
@@ -188,6 +234,68 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column): | |
| ax.plot(etr.index, etr[data_column], color=color, label=label) | ||
|
|
||
|
|
||
| def fip_psth_multiple_nwb_inner_compute( | ||
| nwbs_list, | ||
| align_timepoints, | ||
| channel, | ||
| average, | ||
| tw=[-1, 1], | ||
| censor=True, | ||
| censor_times=None, | ||
| data_column="data", | ||
| ): | ||
| """ | ||
| helper function that computes the event triggered response for a list of nwbs | ||
| nwb, list of nwb objects to get PSTH's for | ||
| align_timepoints, a list of an iterable list of the timepoints to compute the ETR aligned to | ||
| (one for each nwb) | ||
| channel, what channel in the df_fip dataframe to use | ||
| average(bool), whether to return the average, or all individual traces | ||
| tw, time window before and after each event | ||
| censor, censor important timepoints before and after aligned timepoints | ||
| censor_times, timepoints to censor, would also be a list with the same lenght as nwb | ||
| data_column (string), name of data column in nwb.df_fip | ||
|
|
||
| """ | ||
|
|
||
| if censor_times is None: | ||
| censor_times = [None] * len(nwbs_list) | ||
| # check that alignment and nwbs_list match | ||
| assert len(nwbs_list) == len(align_timepoints), "Number of NWBs and align timepoints must match" | ||
| assert len(nwbs_list) == len(censor_times), "Number of NWBs and censor times must match" | ||
|
|
||
| etr_list = [] | ||
| for (i, nwb) in enumerate(nwbs_list): | ||
| data = nwb.df_fip.query("event == @channel") | ||
| etr = an.event_triggered_response( | ||
| data, | ||
| "timestamps", | ||
| data_column, | ||
| align_timepoints[i], | ||
| t_start=tw[0], | ||
| t_end=tw[1], | ||
| output_sampling_rate=40, | ||
| censor=censor, | ||
| censor_times=censor_times[i], | ||
| ) | ||
| etr['ses_idx'] = data.ses_idx.values[0] | ||
| etr_list.append(etr) | ||
|
|
||
| etr_all = pd.concat(etr_list, axis=0).reset_index(drop=True) | ||
| if average: | ||
| # Average within each ses_idx for each time point | ||
| mean_per_ses = etr_all.groupby(['ses_idx', 'time'])[data_column].mean().unstack('ses_idx') | ||
| # Grand mean: average across ses_idx for each time point | ||
| grand_mean = mean_per_ses.mean(axis=1) | ||
| # SEM over ses_idx for each time point | ||
| grand_sem = mean_per_ses.sem(axis=1) | ||
| # Combine into a DataFrame | ||
| result = grand_mean.to_frame(name=data_column) | ||
| result['sem'] = grand_sem | ||
| return result | ||
| return etr_all | ||
|
|
||
|
|
||
| def fip_psth_inner_compute( | ||
| nwb, | ||
| align_timepoints, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why save things to align_vals, then use
.get()? You already checked above if a is in df_eventsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to not repeat it for the if/else logic below.
.get is checking align_dict, not checking nwb_i.df_events.