Skip to content

Commit 69c811c

Browse files
committed
compare_channels
1 parent 08ae615 commit 69c811c

1 file changed

Lines changed: 119 additions & 28 deletions

File tree

  • src/aind_dynamic_foraging_basic_analysis/plot

src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py

Lines changed: 119 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import matplotlib.pyplot as plt
6+
import pandas as pd
67
import numpy as np
78

89
from aind_dynamic_foraging_data_utils import alignment as an
@@ -116,47 +117,76 @@ def plot_fip_psth_compare_channels(
116117
data_column="data",
117118
):
118119
"""
119-
nwb, the nwb object for the session of interest
120+
nwb, the nwb object for the session of interest, or a list of nwb objects
120121
align should either be a string of the name of an event type in nwb.df_events,
121-
or a list of timepoints
122+
or a list of timepoints. if nwb is a list, then align should be a list containing
123+
either the string of the name of an event type, or a list of timepoints.
122124
channels should be a list of channel names (strings)
123125
censor, censor important timepoints before and after aligned timepoints
124126
data_column (string), name of data column in nwb.df_fip
125127
126128
EXAMPLE
127129
********************
128130
plot_fip_psth(nwb, 'goCue_start_time')
131+
plot_fip_psth(nwb_list, 'goCue_start_time')
132+
plot_fip_psth(nwb_list, ['goCue_start_time','goCue_start_time'])
129133
"""
130-
if not hasattr(nwb, "df_fip"):
131-
print("You need to compute the df_fip first")
132-
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
133-
nwb.df_fip = nu.create_fib_df(nwb, tidy=True)
134-
if not hasattr(nwb, "df_events"):
135-
print("You need to compute the df_events first")
136-
print("run `nwb.df_events = create_events_df(nwb)`")
137-
nwb.df_events = nu.create_events_df(nwb)
134+
# Check if nwb is a list, otherwise put it in a list to check
135+
nwb_list = nwb if isinstance(nwb, list) else [nwb]
138136

139-
if isinstance(align, str):
140-
if align not in nwb.df_events["event"].values:
141-
print("{} not found in the events table".format(align))
142-
return
143-
align_timepoints = nwb.df_events.query("event == @align")["timestamps"].values
144-
align_label = "Time from {} (s)".format(align)
145-
else:
146-
align_timepoints = align
147-
align_label = "Time (s)"
137+
if isinstance(nwb, list) and isinstance(align, list) and (len(nwb) != len(align)):
138+
raise Exception("NWB list and align list must match")
139+
if isinstance(nwb, list) and isinstance(align, str):
140+
align = [align] * len(nwb)
141+
if not isinstance(nwb, list):
142+
align = [align]
148143

144+
# First check that each session has an events table and fip table
145+
for nwb_i in nwb_list:
146+
if not hasattr(nwb_i, "df_fip"):
147+
print("You need to compute the df_fip first")
148+
print("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`")
149+
nwb_i.df_fip = nu.create_fib_df(nwb_i, tidy=True)
150+
if not hasattr(nwb_i, "df_events"):
151+
print("You need to compute the df_events first")
152+
print("run `nwb.df_events = create_events_df(nwb)`")
153+
nwb_i.df_events = nu.create_events_df(nwb_i)
154+
155+
align_timepoints_list = []
156+
# Generate the alignment timepoints for each session
157+
for i, nwb_i in enumerate(nwb_list):
158+
align_i = align[i]
159+
if isinstance(align_i, str):
160+
if align_i not in nwb_i.df_events["event"].values:
161+
print("{} not found in the events table, {}".format(align_i, nwb_i.session_id))
162+
return
163+
164+
align_timepoints_list.append(
165+
nwb_i.df_events.query("event == @align")["timestamps"].values
166+
)
167+
align_label = "Time from {} (s)".format(align_i)
168+
else:
169+
align_timepoints_list.append(align_i)
170+
align_label = "Time (s)"
171+
172+
# Make figure if not supplied
149173
if fig is None and ax is None:
150174
fig, ax = plt.subplots()
151175

176+
# Iterate through channels and plot
152177
colors = [FIP_COLORS.get(c, "") for c in channels]
153178
for dex, c in enumerate(channels):
154-
if c in nwb.df_fip["event"].values:
155-
etr = fip_psth_inner_compute(nwb, align_timepoints, c, True, tw,
156-
censor, data_column=data_column)
157-
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
158-
else:
159-
print("No data for channel: {}".format(c))
179+
include = [c in nwb.df_fip["event"].values for nwb in nwb_list]
180+
etr = fip_psth_multiple_inner_compute(
181+
[x for dex, x in enumerate(nwb_list) if include[dex]],
182+
[x for dex, x in enumerate(align_timepoints_list) if include[dex]],
183+
c,
184+
True,
185+
tw,
186+
censor,
187+
data_column=data_column,
188+
)
189+
fip_psth_inner_plot(ax, etr, colors[dex], c, data_column)
160190

161191
plt.legend()
162192
ax.set_xlabel(align_label, fontsize=STYLE["axis_fontsize"])
@@ -166,7 +196,10 @@ def plot_fip_psth_compare_channels(
166196
ax.set_xlim(tw)
167197
ax.axvline(0, color="k", alpha=0.2)
168198
ax.tick_params(axis="both", labelsize=STYLE["axis_ticks_fontsize"])
169-
ax.set_title(nwb.session_id)
199+
if len(nwb_list) == 1:
200+
ax.set_title(nwb_list[0].session_id)
201+
else:
202+
ax.set_title("{} sessions".format(len(nwb_list)))
170203
plt.tight_layout()
171204
return fig, ax
172205

@@ -183,11 +216,69 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column):
183216
if color == "":
184217
cmap = plt.get_cmap("tab20")
185218
color = cmap(np.random.randint(20))
186-
ax.fill_between(etr.index, etr[data_column] - etr["sem"],
187-
etr[data_column] + etr["sem"], color=color, alpha=0.2)
219+
ax.fill_between(
220+
etr.index,
221+
etr[data_column] - etr["sem"],
222+
etr[data_column] + etr["sem"],
223+
color=color,
224+
alpha=0.2,
225+
)
188226
ax.plot(etr.index, etr[data_column], color=color, label=label)
189227

190228

229+
def fip_psth_multiple_inner_compute(
230+
nwb_list,
231+
align_timepoints_list,
232+
channel,
233+
average,
234+
tw=[-1, 1],
235+
censor=True,
236+
censor_times=None,
237+
data_column="data",
238+
):
239+
""" """
240+
# Check that len(nwb_list) = len(align_timepoints_list) = len(censor_times)
241+
if len(nwb_list) != len(align_timepoints_list):
242+
raise Exception("length of nwb list and alignments list must match")
243+
if censor and censor_times is None:
244+
censor_times = [None] * len(nwb_list)
245+
if censor and (len(nwb_list) != len(censor_times)):
246+
raise Exception("length of nwb list and censor times must match")
247+
248+
etr_list = []
249+
# Iterate through list of sessions, computing the etr for each
250+
for i, nwb_i in enumerate(nwb_list):
251+
etr_i = fip_psth_inner_compute(
252+
nwb_i,
253+
align_timepoints_list[i],
254+
channel,
255+
average=False,
256+
tw=tw,
257+
censor=censor,
258+
censor_times=censor_times[i],
259+
data_column=data_column,
260+
)
261+
etr_i["ses_idx"] = nwb_i.session_id
262+
etr_list.append(etr_i)
263+
264+
# Concat etrs from each session into one dataframe
265+
etr_all = pd.concat(etr_list, axis=0).reset_index(drop=True)
266+
267+
if average:
268+
# Average within each ses_idx for each time point
269+
mean_per_ses = etr_all.groupby(["ses_idx", "time"])[data_column].mean().unstack("ses_idx")
270+
# Grand mean: average across ses_idx for each time point
271+
grand_mean = mean_per_ses.mean(axis=1)
272+
# SEM over ses_idx for each time point
273+
grand_sem = mean_per_ses.sem(axis=1)
274+
# Combine into a DataFrame
275+
result = grand_mean.to_frame(name=data_column)
276+
result["sem"] = grand_sem
277+
return result
278+
else:
279+
return etr_all
280+
281+
191282
def fip_psth_inner_compute(
192283
nwb,
193284
align_timepoints,

0 commit comments

Comments
 (0)