Skip to content

Commit 6e9c7e3

Browse files
committed
adding error_type cleaning up, linting
1 parent 2d2de43 commit 6e9c7e3

1 file changed

Lines changed: 36 additions & 47 deletions

File tree

  • src/aind_dynamic_foraging_basic_analysis/plot

src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
"""
44

55
import matplotlib.pyplot as plt
6-
import pandas as pd
76
import numpy as np
8-
7+
import pandas as pd
98
from aind_dynamic_foraging_data_utils import alignment as an
109
from aind_dynamic_foraging_data_utils import nwb_utils as nu
11-
from aind_dynamic_foraging_basic_analysis.plot.style import STYLE, FIP_COLORS
10+
11+
from aind_dynamic_foraging_basic_analysis.plot.style import FIP_COLORS, STYLE
1212

1313

1414
def plot_fip_psth_compare_alignments( # NOQA C901
@@ -21,6 +21,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2121
censor=True,
2222
extra_colors={},
2323
data_column="data",
24+
error_type="sem",
2425
):
2526
"""
2627
Compare the same FIP channel aligned to multiple event types
@@ -39,6 +40,8 @@ def plot_fip_psth_compare_alignments( # NOQA C901
3940
plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
4041
'right_reward_delivery_time'],'G_1_preprocessed')
4142
"""
43+
if error_type not in ["sem", "sem_over_sessions"]:
44+
raise Exception("unknown error type")
4245

4346
nwb_list = nwb if isinstance(nwb, list) else [nwb]
4447
for nwb_i in nwb_list:
@@ -102,23 +105,11 @@ def plot_fip_psth_compare_alignments( # NOQA C901
102105

103106
align_label = "Time (s)"
104107
for alignment in align_list[0]:
105-
if len(nwb_list) == 1:
106-
etr = fip_psth_inner_compute(
107-
nwb_list[0],
108-
align_list[0][alignment],
109-
channel,
110-
True,
111-
tw,
112-
censor,
113-
censor_times_list[0],
114-
data_column,
115-
)
116-
else:
117-
this_align = [x[alignment] for x in align_list]
118-
etr = fip_psth_multiple_inner_compute(
119-
nwb_list, this_align, channel, True, tw, censor, censor_times_list, data_column
120-
)
121-
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column)
108+
this_align = [x[alignment] for x in align_list]
109+
etr = fip_psth_multiple_inner_compute(
110+
nwb_list, this_align, channel, True, tw, censor, censor_times_list, data_column
111+
)
112+
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column, error_type)
122113

123114
plt.legend()
124115
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
@@ -152,6 +143,7 @@ def plot_fip_psth_compare_channels(
152143
],
153144
censor=True,
154145
data_column="data",
146+
error_type="sem",
155147
):
156148
"""
157149
nwb, the nwb object for the session of interest, or a list of nwb objects
@@ -168,6 +160,10 @@ def plot_fip_psth_compare_channels(
168160
plot_fip_psth(nwb_list, 'goCue_start_time')
169161
plot_fip_psth(nwb_list, ['goCue_start_time','goCue_start_time'])
170162
"""
163+
164+
if error_type not in ["sem", "sem_over_sessions"]:
165+
raise Exception("Unknown error type")
166+
171167
# Check if nwb is a list, otherwise put it in a list to check
172168
nwb_list = nwb if isinstance(nwb, list) else [nwb]
173169

@@ -213,29 +209,17 @@ def plot_fip_psth_compare_channels(
213209
# Iterate through channels and plot
214210
colors = [FIP_COLORS.get(c, "") for c in channels]
215211
for dex, c in enumerate(channels):
216-
if len(nwb_list) == 1:
217-
if c in nwb_list[0].df_fip["event"].values:
218-
etr = fip_psth_inner_compute(
219-
nwb_list[0],
220-
align_timepoints_list[0],
221-
c,
222-
True,
223-
tw,
224-
censor,
225-
data_column=data_column,
226-
)
227-
else:
228-
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
229-
etr = fip_psth_multiple_inner_compute(
230-
[x for dex, x in enumerate(nwb_list) if include[dex]],
231-
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
232-
c,
233-
True,
234-
tw,
235-
censor,
236-
data_column=data_column,
237-
)
238-
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
212+
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
213+
etr = fip_psth_multiple_inner_compute(
214+
[x for dex, x in enumerate(nwb_list) if include[dex]],
215+
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
216+
c,
217+
True,
218+
tw,
219+
censor,
220+
data_column=data_column,
221+
)
222+
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column, error_type)
239223

240224
plt.legend()
241225
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
@@ -253,22 +237,23 @@ def plot_fip_psth_compare_channels(
253237
return fig, ax
254238

255239

256-
def fip_psth_inner_plot(ax, etr, color, label, data_column):
240+
def fip_psth_inner_plot(ax, etr, color, label, data_column, error_type="sem"):
257241
"""
258242
helper function that plots an event triggered response
259243
ax, the pyplot axis to plot on
260244
etr, the dataframe that contains the event triggered response
261245
color, the line color to plot
262246
label, the label for the etr
263247
data_column (string), name of data_column
248+
error_type, the error bar type to plot, must be a column in etr
264249
"""
265250
if color == "":
266251
cmap = plt.get_cmap("tab20")
267252
color = cmap(np.random.randint(20))
268253
ax.fill_between(
269254
etr.index,
270-
etr[data_column] - etr["sem"],
271-
etr[data_column] + etr["sem"],
255+
etr[data_column] - etr[error_type],
256+
etr[data_column] + etr[error_type],
272257
color=color,
273258
alpha=0.2,
274259
)
@@ -322,7 +307,11 @@ def fip_psth_multiple_inner_compute(
322307
grand_sem = mean_per_ses.sem(axis=1)
323308
# Combine into a DataFrame
324309
result = grand_mean.to_frame(name=data_column)
325-
result["sem"] = grand_sem
310+
result["sem_over_sessions"] = grand_sem
311+
312+
# Compute SEM collapsing over sessions
313+
result["sem"] = etr_all.groupby("time")[data_column].sem()
314+
326315
return result
327316
else:
328317
return etr_all

0 commit comments

Comments
 (0)