-
Notifications
You must be signed in to change notification settings - Fork 0
Psth multiple nwbs #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
4083931
3505139
db1bd13
886c250
5a839b4
d25325d
3e5b69f
fba8197
950b690
ffde75f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| """ | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import pandas as pd | ||
| import numpy as np | ||
|
|
||
| from aind_dynamic_foraging_data_utils import alignment as an | ||
|
|
@@ -38,16 +39,19 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time', | ||
| 'right_reward_delivery_time'],'G_1_preprocessed') | ||
| """ | ||
| if not hasattr(nwb, "df_fip"): | ||
| # Check if nwb is a list, and if so, check only the first element for attributes and channel | ||
| nwb_to_check = nwb[0] if isinstance(nwb, list) else nwb | ||
|
|
||
| if not hasattr(nwb_to_check, "df_fip"): | ||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb.df_fip = nu.create_fib_df(nwb, tidy=True) | ||
| if not hasattr(nwb, "df_events"): | ||
| nwb_to_check.df_fip = nu.create_fib_df(nwb_to_check, tidy=True) | ||
| if not hasattr(nwb_to_check, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb.df_events = nu.create_events_df(nwb) | ||
| nwb_to_check.df_events = nu.create_events_df(nwb_to_check) | ||
|
|
||
| if channel not in nwb.df_fip["event"].values: | ||
| if channel not in nwb_to_check.df_fip["event"].values: | ||
| print("channel {} not in df_fip".format(channel)) | ||
|
|
||
| if isinstance(alignments, list): | ||
|
|
@@ -60,6 +64,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values | ||
| elif isinstance(alignments, dict): | ||
| align_dict = alignments | ||
| align_dict_flat = alignments.copy() | ||
| else: | ||
| print( | ||
| "alignments must be either a list of events in nwb.df_events, " | ||
|
|
@@ -70,7 +75,9 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
|
|
||
| censor_times = [] | ||
| for key in align_dict: | ||
| censor_times.append(align_dict[key]) | ||
| if isinstance(nwb, list): | ||
| align_dict_flat[key] = np.concatenate(align_dict[key]) | ||
| censor_times.append(align_dict_flat[key]) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure what you are doing here with
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This feels wrong, because the censor times from one NWB shouldn't impact the censor times of another NWB
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you're right-- fixed. |
||
| censor_times = np.sort(np.concatenate(censor_times)) | ||
|
|
||
| align_label = "Time (s)" | ||
|
|
@@ -80,9 +87,18 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| colors = {**FIP_COLORS, **extra_colors} | ||
|
|
||
| for alignment in align_dict: | ||
| etr = fip_psth_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| if isinstance(nwb, list): | ||
| # Compute etr for every NWB object in the list and average | ||
| etr = fip_psth_multiple_nwb_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) | ||
|
|
||
| else: | ||
| etr = fip_psth_inner_compute( | ||
| nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column | ||
| ) | ||
| session_id_title = nwb.session_id | ||
| fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column) | ||
|
|
||
| plt.legend() | ||
|
|
@@ -93,7 +109,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901 | |
| ax.set_xlim(tw) | ||
| ax.axvline(0, color="k", alpha=0.2) | ||
| ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"]) | ||
| ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"]) | ||
| ax.set_title(session_id_title, fontsize=STYLE["axis_fontsize"]) | ||
| plt.tight_layout() | ||
| return fig, ax | ||
|
|
||
|
|
@@ -127,20 +143,27 @@ def plot_fip_psth_compare_channels( | |
| ******************** | ||
| plot_fip_psth(nwb, 'goCue_start_time') | ||
| """ | ||
| if not hasattr(nwb, "df_fip"): | ||
| # Check if nwb is a list, and if so, check only the first element for attributes and channel | ||
| nwb_to_check = nwb[0] if isinstance(nwb, list) else nwb | ||
|
|
||
| if not hasattr(nwb_to_check, "df_fip"): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, I would check all NWBs |
||
| print("You need to compute the df_fip first") | ||
| print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`") | ||
| nwb.df_fip = nu.create_fib_df(nwb, tidy=True) | ||
| if not hasattr(nwb, "df_events"): | ||
| nwb_to_check.df_fip = nu.create_fib_df(nwb_to_check, tidy=True) | ||
| if not hasattr(nwb_to_check, "df_events"): | ||
| print("You need to compute the df_events first") | ||
| print("run `nwb.df_events = create_events_df(nwb)`") | ||
| nwb.df_events = nu.create_events_df(nwb) | ||
| nwb_to_check.df_events = nu.create_events_df(nwb_to_check) | ||
|
|
||
| if isinstance(align, str): | ||
| if align not in nwb.df_events["event"].values: | ||
| if align not in nwb_to_check.df_events["event"].values: | ||
| print("{} not found in the events table".format(align)) | ||
| return | ||
| align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values | ||
| if isinstance(nwb, list): | ||
| align_timepoints = [nwb_i.df_events.query("event == @align")["timestamps"].values | ||
| for nwb_i in nwb] | ||
| else: | ||
| align_timepoints = nwb_to_check.df_events.query("event == @align")["timestamps"].values | ||
| align_label = "Time from {} (s)".format(align) | ||
| else: | ||
| align_timepoints = align | ||
|
|
@@ -151,9 +174,18 @@ def plot_fip_psth_compare_channels( | |
|
|
||
| colors = [FIP_COLORS.get(c, "") for c in channels] | ||
| for dex, c in enumerate(channels): | ||
| if c in nwb.df_fip["event"].values: | ||
| etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw, | ||
| censor, data_column=data_column) | ||
| if c in nwb_to_check.df_fip["event"].values: | ||
| if isinstance(nwb, list): | ||
| # Compute etr for every NWB object in the list and average | ||
| etr = fip_psth_multiple_nwb_inner_compute(nwb, align_timepoints, c, True, tw, | ||
| censor, data_column=data_column) | ||
| session_id_title = ', '.join([nwb_i.session_id for nwb_i in nwb]) | ||
|
|
||
| else: | ||
| etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw, | ||
| censor, data_column=data_column) | ||
| session_id_title = nwb.session_id | ||
|
|
||
| fip_psth_inner_plot(ax, etr, colors[dex], c, data_column) | ||
| else: | ||
| print("No data for channel: {}".format(c)) | ||
|
|
@@ -166,7 +198,7 @@ def plot_fip_psth_compare_channels( | |
| ax.set_xlim(tw) | ||
| ax.axvline(0, color="k", alpha=0.2) | ||
| ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"]) | ||
| ax.set_title(nwb.session_id) | ||
| ax.set_title(session_id_title) | ||
| plt.tight_layout() | ||
| return fig, ax | ||
|
|
||
|
|
@@ -188,6 +220,60 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column): | |
| ax.plot(etr.index, etr[data_column], color=color, label=label) | ||
|
|
||
|
|
||
| def fip_psth_multiple_nwb_inner_compute( | ||
| nwbs_list, | ||
| align_timepoints, | ||
| channel, | ||
| average, | ||
| tw=[-1, 1], | ||
| censor=True, | ||
| censor_times=None, | ||
| data_column="data", | ||
| ): | ||
| """ | ||
| helper function that computes the event triggered response | ||
| nwb, nwb object for the session of interest, should have df_fip attribute | ||
| align_timepoints, an iterable list of the timepoints to compute the ETR aligned to | ||
| channel, what channel in the df_fip dataframe to use | ||
| average(bool), whether to return the average, or all individual traces | ||
| tw, time window before and after each event | ||
| censor, censor important timepoints before and after aligned timepoints | ||
| censor_times, timepoints to censor | ||
| data_column (string), name of data column in nwb.df_fip | ||
|
|
||
| """ | ||
| etr_list = [] | ||
| for (i, nwb) in enumerate(nwbs_list): | ||
| data = nwb.df_fip.query("event == @channel") | ||
| etr = an.event_triggered_response( | ||
| data, | ||
| "timestamps", | ||
| data_column, | ||
| align_timepoints[i], | ||
| t_start=tw[0], | ||
| t_end=tw[1], | ||
| output_sampling_rate=40, | ||
| censor=censor, | ||
| censor_times=censor_times, | ||
| ) | ||
| etr['ses_idx'] = data.ses_idx.values[0] | ||
| etr_list.append(etr) | ||
|
|
||
| etr_all = pd.concat(etr_list, axis=0).reset_index(drop=True) | ||
| if average: | ||
| # Average within each ses_idx for each time point | ||
| mean_per_ses = etr_all.groupby(['ses_idx', 'time'])[data_column].mean().unstack('ses_idx') | ||
| # Grand mean: average across ses_idx for each time point | ||
| grand_mean = mean_per_ses.mean(axis=1) | ||
| # SEM over ses_idx for each time point | ||
| grand_sem = mean_per_ses.sem(axis=1) | ||
| # Combine into a DataFrame | ||
| result = grand_mean.to_frame(name=data_column) | ||
| result['sem'] = grand_sem | ||
| return result | ||
| return etr_all | ||
|
|
||
|
|
||
| def fip_psth_inner_compute( | ||
| nwb, | ||
| align_timepoints, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why check only the first NWB? Checks are fast, so we should just check all of them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed. done. am testing and wll send PR back to you once I do that.