Skip to content

Commit 2d2de43

Browse files
committed
dev
1 parent 69c811c commit 2d2de43

1 file changed

Lines changed: 92 additions & 43 deletions

File tree

  • src/aind_dynamic_foraging_basic_analysis/plot

src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py

Lines changed: 92 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2424
):
2525
"""
2626
Compare the same FIP channel aligned to multiple event types
27-
nwb, nwb object for the session
27+
nwb, nwb object for the session, or a list of nwbs
2828
alignments, either a list of event types in df_events, or a dictionary
2929
whose keys are event types and values are a list of timepoints
3030
channel, (str) the name of the FIP channel
@@ -39,28 +39,43 @@ def plot_fip_psth_compare_alignments( # NOQA C901
3939
plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
4040
'right_reward_delivery_time'],'G_1_preprocessed')
4141
"""
42-
if not hasattr(nwb, "df_fip"):
43-
print("You need to compute the df_fip first")
44-
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
45-
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
46-
if not hasattr(nwb, "df_events"):
47-
print("You need to compute the df_events first")
48-
print("run `nwb.df_events = create_events_df(nwb)`")
49-
nwb.df_events = nu.create_events_df(nwb)
50-
51-
if channel not in nwb.df_fip["event"].values:
52-
print("channel {} not in df_fip".format(channel))
53-
54-
if isinstance(alignments, list):
55-
align_dict = {}
56-
for a in alignments:
57-
if a not in nwb.df_events["event"].values:
58-
print("{} not found in the events table".format(a))
59-
return
60-
else:
61-
align_dict[a] = nwb.df_events.query("event == @a")["timestamps"].values
62-
elif isinstance(alignments, dict):
63-
align_dict = alignments
42+
43+
nwb_list = nwb if isinstance(nwb, list) else [nwb]
44+
for nwb_i in nwb_list:
45+
if not hasattr(nwb_i, "df_fip"):
46+
print("You need to compute the df_fip first")
47+
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
48+
nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True)
49+
if not hasattr(nwb_i, "df_events"):
50+
print("You need to compute the df_events first")
51+
print("run `nwb.df_events = create_events_df(nwb)`")
52+
nwb_i.df_events = nu.create_events_df(nwb_i)
53+
if channel not in nwb_i.df_fip["event"].values:
54+
print("channel {} not in df_fip".format(channel))
55+
56+
# if single nwb - can pass list, or dictionary
57+
# if list of nwbs - can pass a single list, or list of dictionaries
58+
if len(nwb_list) == 1 and not (isinstance(alignments, list) or isinstance(alignments, dict)):
59+
raise Exception("Must pass alignments as a list of events, or a dictionary of times")
60+
elif len(nwb_list) > 1 and (not isinstance(alignments, list)):
61+
raise Exception(
62+
"Must pass alignments as a list of events, or a list of dictionariesof times"
63+
)
64+
65+
if isinstance(alignments, dict):
66+
# We have a single NWB, given a dictionary of alignments, make it a list and we are done
67+
align_list = [alignments]
68+
elif isinstance(alignments, list):
69+
align_list = []
70+
for i, nwb_i in enumerate(nwb_list):
71+
align_dict = {}
72+
for a in alignments:
73+
if a not in nwb_i.df_events["event"].values:
74+
print("{} not found in the events table: {}".format(a, nwb_i.session_id))
75+
return
76+
else:
77+
align_dict[a] = nwb_i.df_events.query("event == @a")["timestamps"].values
78+
align_list.append(align_dict)
6479
else:
6580
print(
6681
"alignments must be either a list of events in nwb.df_events, "
@@ -69,21 +84,40 @@ def plot_fip_psth_compare_alignments( # NOQA C901
6984
)
7085
return
7186

72-
censor_times = []
73-
for key in align_dict:
74-
censor_times.append(align_dict[key])
75-
censor_times = np.sort(np.concatenate(censor_times))
87+
# Compute censor times
88+
censor_times_list = []
89+
for i, nwb_i in enumerate(nwb_list):
90+
censor_times = []
91+
for key in align_list[i]:
92+
censor_times.append(align_list[i][key])
93+
censor_times = np.sort(np.concatenate(censor_times))
94+
censor_times_list.append(censor_times)
7695

77-
align_label = "Time (s)"
96+
# Create figure if not supplied
7897
if fig is None and ax is None:
7998
fig, ax = plt.subplots()
8099

100+
# Get colors
81101
colors = {**FIP_COLORS, **extra_colors}
82102

83-
for alignment in align_dict:
84-
etr = fip_psth_inner_compute(
85-
nwb, align_dict[alignment], channel, True, tw, censor, censor_times, data_column
86-
)
103+
align_label = "Time (s)"
104+
for alignment in align_list[0]:
105+
if len(nwb_list) == 1:
106+
etr = fip_psth_inner_compute(
107+
nwb_list[0],
108+
align_list[0][alignment],
109+
channel,
110+
True,
111+
tw,
112+
censor,
113+
censor_times_list[0],
114+
data_column,
115+
)
116+
else:
117+
this_align = [x[alignment] for x in align_list]
118+
etr = fip_psth_multiple_inner_compute(
119+
nwb_list, this_align, channel, True, tw, censor, censor_times_list, data_column
120+
)
87121
fip_psth_inner_plot(ax, etr, colors.get(alignment, ""), alignment, data_column)
88122

89123
plt.legend()
@@ -94,7 +128,10 @@ def plot_fip_psth_compare_alignments( # NOQA C901
94128
ax.set_xlim(tw)
95129
ax.axvline(0, color="k", alpha=0.2)
96130
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
97-
ax.set_title(nwb.session_id, fontsize=STYLE["axis_fontsize"])
131+
if len(nwb_list) == 1:
132+
ax.set_title(nwb_list[0].session_id, fontsize=STYLE["axis_fontsize"])
133+
else:
134+
ax.set_title("{} sessions".format(len(nwb_list)), fontsize=STYLE["axis_fontsize"])
98135
plt.tight_layout()
99136
return fig, ax
100137

@@ -176,16 +213,28 @@ def plot_fip_psth_compare_channels(
176213
# Iterate through channels and plot
177214
colors = [FIP_COLORS.get(c, "") for c in channels]
178215
for dex, c in enumerate(channels):
179-
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
180-
etr = fip_psth_multiple_inner_compute(
181-
[x for dex, x in enumerate(nwb_list) if include[dex]],
182-
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
183-
c,
184-
True,
185-
tw,
186-
censor,
187-
data_column=data_column,
188-
)
216+
if len(nwb_list) == 1:
217+
if c in nwb_list[0].df_fip["event"].values:
218+
etr = fip_psth_inner_compute(
219+
nwb_list[0],
220+
align_timepoints_list[0],
221+
c,
222+
True,
223+
tw,
224+
censor,
225+
data_column=data_column,
226+
)
227+
else:
228+
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
229+
etr = fip_psth_multiple_inner_compute(
230+
[x for dex, x in enumerate(nwb_list) if include[dex]],
231+
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
232+
c,
233+
True,
234+
tw,
235+
censor,
236+
data_column=data_column,
237+
)
189238
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
190239

191240
plt.legend()

0 commit comments

Comments
 (0)