Skip to content

Commit f933d21

Browse files
committed
dev
1 parent b9a98c5 commit f933d21

1 file changed

Lines changed: 28 additions & 3 deletions

File tree

  • src/aind_dynamic_foraging_basic_analysis/plot

src/aind_dynamic_foraging_basic_analysis/plot/plot_fip.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2626
"""
2727
Compare the same FIP channel aligned to multiple event types
2828
nwb, nwb object for the session, or a list of nwbs
29-
alignments, either a list of event types in df_events, or a dictionary
30-
whose keys are event types and values are a list of timepoints
29+
alignments, with one session alignments can be either a list of
30+
event types in df_events, or a dictionary whose keys are
31+
event types and values are a list of timepoints. With multiple
32+
sessions, alignments can be either a list of event types in df_events,
33+
or a list of dictionaries whose keys are event types and values are a
34+
list of timepoints.
3135
channel, (str) the name of the FIP channel
3236
tw, time window for the PSTH
3337
censor, censor important timepoints before and after aligned timepoints
@@ -67,9 +71,25 @@ def plot_fip_psth_compare_alignments( # NOQA C901
6771
raise Exception("Must pass alignments as a list of events, or a dictionary of times")
6872
elif len(nwb_list) > 1 and (not isinstance(alignments, list)):
6973
raise Exception(
70-
"Must pass alignments as a list of events, or a list of dictionariesof times"
74+
"Must pass alignments as a list of events, or a list of dictionaries of times"
7175
)
7276

77+
# If we are given a list of dictionaries, ensure all dictionaries have the same keys
78+
if (
79+
len(nwb_list) > 1
80+
and isinstance(alignments, list)
81+
and all(isinstance(item, dict) for item in alignments)
82+
):
83+
keys = set()
84+
for d in alignments:
85+
keys.update(list(d.keys()))
86+
for index, d in enumerate(alignments):
87+
missing = keys - set(d.keys())
88+
if len(missing) > 0:
89+
raise Exception(
90+
"{} Missing alignment key: {}".format(nwb_list[index].session_id, list(missing))
91+
)
92+
7393
if isinstance(alignments, dict):
7494
# We have a single NWB, given a dictionary of alignments, make it a list and we are done
7595
align_list = [alignments]
@@ -203,6 +223,11 @@ def plot_fip_psth_compare_channels( # NOQA C901
203223
print("run `nwb.df_events = create_events_df(nwb)`")
204224
nwb_i.df_events = nu.create_events_df(nwb_i)
205225

226+
# Add warning if channels are missing
227+
missing_channels = [c for c in channels if c not in nwb_i.df_fip["event"].values]
228+
if len(missing_channels) > 0:
229+
print("{} missing channel {}".format(nwb_i.session_id, missing_channels))
230+
206231
align_timepoints_list = []
207232
# Generate the alignment timepoints for each session
208233
for i, nwb_i in enumerate(nwb_list):

0 commit comments

Comments
 (0)