@@ -26,8 +26,12 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2626 """
2727 Compare the same FIP channel aligned to multiple event types
2828 nwb, nwb object for the session, or a list of nwbs
29- alignments, either a list of event types in df_events, or a dictionary
30- whose keys are event types and values are a list of timepoints
29+ alignments, with one session alignments can be either a list of
30+ event types in df_events, or a dictionary whose keys are
31+ event types and values are a list of timepoints. With multiple
32+ sessions, alignments can be either a list of event types in df_events,
33+ or a list of dictionaries whose keys are event types and values are a
34+ list of timepoints.
3135 channel, (str) the name of the FIP channel
3236 tw, time window for the PSTH
3337 censor, censor important timepoints before and after aligned timepoints
@@ -67,9 +71,25 @@ def plot_fip_psth_compare_alignments( # NOQA C901
6771 raise Exception ("Must pass alignments as a list of events, or a dictionary of times" )
6872 elif len (nwb_list ) > 1 and (not isinstance (alignments , list )):
6973 raise Exception (
70- "Must pass alignments as a list of events, or a list of dictionariesof times"
74+ "Must pass alignments as a list of events, or a list of dictionaries of times"
7175 )
7276
77+ # If we are given a list of dictionaries, ensure all dictionaries have the same keys
78+ if (
79+ len (nwb_list ) > 1
80+ and isinstance (alignments , list )
81+ and all (isinstance (item , dict ) for item in alignments )
82+ ):
83+ keys = set ()
84+ for d in alignments :
85+ keys .update (list (d .keys ()))
86+ for index , d in enumerate (alignments ):
87+ missing = keys - set (d .keys ())
88+ if len (missing ) > 0 :
89+ raise Exception (
90+ "{} Missing alignment key: {}" .format (nwb_list [index ].session_id , list (missing ))
91+ )
92+
7393 if isinstance (alignments , dict ):
7494 # We have a single NWB, given a dictionary of alignments, make it a list and we are done
7595 align_list = [alignments ]
@@ -203,6 +223,11 @@ def plot_fip_psth_compare_channels( # NOQA C901
203223 print ("run `nwb.df_events = create_events_df(nwb)`" )
204224 nwb_i .df_events = nu .create_events_df (nwb_i )
205225
226+ # Add warning if channels are missing
227+ missing_channels = [c for c in channels if c not in nwb_i .df_fip ["event" ].values ]
228+ if len (missing_channels ) > 0 :
229+ print ("{} missing channel {}" .format (nwb_i .session_id , missing_channels ))
230+
206231 align_timepoints_list = []
207232 # Generate the alignment timepoints for each session
208233 for i , nwb_i in enumerate (nwb_list ):
0 commit comments