Skip to content

Commit b9a98c5

Browse files
committed
fixing review issues
1 parent 131bec3 commit b9a98c5

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

  • src/aind_dynamic_foraging_basic_analysis/plot

src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def plot_fip_psth_compare_alignments( # NOQA C901
3434
extra_colors (dict), a dictionary of extra colors.
3535
keys should be alignments, or colors are random
3636
data_column (string), name of data column in nwb.df_fip
37+
error_type, (string), either "sem" or "sem_over_sessions" to define
38+
the error bar for the PSTH
3739
3840
EXAMPLE
3941
*******************
@@ -44,6 +46,9 @@ def plot_fip_psth_compare_alignments( # NOQA C901
4446
raise Exception("unknown error type")
4547

4648
nwb_list = nwb if isinstance(nwb, list) else [nwb]
49+
if len(nwb_list) == 1 and error_type == "sem_over_sessions":
50+
raise Exception("Cannot have sem_over_sessions with one session")
51+
4752
for nwb_i in nwb_list:
4853
if not hasattr(nwb_i, "df_fip"):
4954
print("You need to compute the df_fip first")
@@ -84,8 +89,9 @@ def plot_fip_psth_compare_alignments( # NOQA C901
8489
else:
8590
print(
8691
"alignments must be either a list of events in nwb.df_events, "
87-
+ "or a dictionary where each key is an event type, "
88-
+ "and the value is a list of timepoints"
92+
+ "or, for a single session, a dictionary where each key is an event type, "
93+
+ "and the value is a list of timepoints. If multiple sessions are given, "
94+
+ "you may pass a list of dictionaries"
8995
)
9096
return
9197

@@ -151,16 +157,18 @@ def plot_fip_psth_compare_channels( # NOQA C901
151157
nwb, the nwb object for the session of interest, or a list of nwb objects
152158
align should either be a string of the name of an event type in nwb.df_events,
153159
or a list of timepoints. if nwb is a list, then align should be a list containing
154-
either the string of the name of an event type, or a list of timepoints.
160+
lists of timepoints for each session.
155161
channels should be a list of channel names (strings)
156162
censor, censor important timepoints before and after aligned timepoints
157163
data_column (string), name of data column in nwb.df_fip
164+
error_type, (string), either "sem" or "sem_over_sessions" to define
165+
the error bar for the PSTH
158166
159167
EXAMPLE
160168
********************
161169
plot_fip_psth(nwb, 'goCue_start_time')
162170
plot_fip_psth(nwb_list, 'goCue_start_time')
163-
plot_fip_psth(nwb_list, ['goCue_start_time','goCue_start_time'])
171+
plot_fip_psth(nwb_list, [session_1_timepoints, session_2_timepoints, ... ])
164172
"""
165173

166174
if error_type not in ["sem", "sem_over_sessions"]:
@@ -169,8 +177,16 @@ def plot_fip_psth_compare_channels( # NOQA C901
169177
# Check if nwb is a list, otherwise put it in a list to check
170178
nwb_list = nwb if isinstance(nwb, list) else [nwb]
171179

180+
if len(nwb_list) == 1 and error_type == "sem_over_sessions":
181+
raise Exception("Cannot have sem_over_sessions with one session")
172182
if isinstance(nwb, list) and isinstance(align, list) and (len(nwb) != len(align)):
173183
raise Exception("NWB list and align list must match")
184+
if (
185+
isinstance(nwb, list)
186+
and isinstance(align, list)
187+
and not all(isinstance(item, list) or isinstance(item, np.ndarray) for item in align)
188+
):
189+
raise Exception("When using multiple sessions, align must be a list of lists")
174190
if isinstance(nwb, list) and isinstance(align, str):
175191
align = [align] * len(nwb)
176192
if not isinstance(nwb, list):
@@ -247,7 +263,7 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column, error_type="sem"):
247263
color, the line color to plot
248264
label, the label for the etr
249265
data_column (string), name of data_column
250-
error_type, the error bar type to plot, must be a column in etr
266+
error_type, (string), the error bar type to plot, must be a column in etr
251267
"""
252268
if color == "":
253269
cmap = plt.get_cmap("tab20")

0 commit comments

Comments
 (0)