@@ -24,7 +24,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2424):
2525 """
2626 Compare the same FIP channel aligned to multiple event types
27- nwb, nwb object for the session
27+ nwb, nwb object for the session, or a list of nwbs
2828 alignments, either a list of event types in df_events, or a dictionary
2929 whose keys are event types and values are a list of timepoints
3030 channel, (str) the name of the FIP channel
@@ -39,28 +39,43 @@ def plot_fip_psth_compare_alignments( # NOQA C901
3939 plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
4040 'right_reward_delivery_time'],'G_1_preprocessed')
4141 """
42- if not hasattr (nwb , "df_fip" ):
43- print ("You need to compute the df_fip first" )
44- print ("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`" )
45- nwb .df_fip = nu .create_fib_df (nwb , tidy = True )
46- if not hasattr (nwb , "df_events" ):
47- print ("You need to compute the df_events first" )
48- print ("run `nwb.df_events = create_events_df(nwb)`" )
49- nwb .df_events = nu .create_events_df (nwb )
50-
51- if channel not in nwb .df_fip ["event" ].values :
52- print ("channel {} not in df_fip" .format (channel ))
53-
54- if isinstance (alignments , list ):
55- align_dict = {}
56- for a in alignments :
57- if a not in nwb .df_events ["event" ].values :
58- print ("{} not found in the events table" .format (a ))
59- return
60- else :
61- align_dict [a ] = nwb .df_events .query ("event == @a" )["timestamps" ].values
62- elif isinstance (alignments , dict ):
63- align_dict = alignments
42+
43+ nwb_list = nwb if isinstance (nwb , list ) else [nwb ]
44+ for nwb_i in nwb_list :
45+ if not hasattr (nwb_i , "df_fip" ):
46+ print ("You need to compute the df_fip first" )
47+ print ("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`" )
48+ nwb_i .df_fip = nu .create_fib_df (nwb_i , tidy = True )
49+ if not hasattr (nwb_i , "df_events" ):
50+ print ("You need to compute the df_events first" )
51+ print ("run `nwb.df_events = create_events_df(nwb)`" )
52+ nwb_i .df_events = nu .create_events_df (nwb_i )
53+ if channel not in nwb_i .df_fip ["event" ].values :
54+ print ("channel {} not in df_fip" .format (channel ))
55+
56+ # if single nwb - can pass list, or dictionary
57+ # if list of nwbs - can pass a single list, or list of dictionaries
58+ if len (nwb_list ) == 1 and not (isinstance (alignments , list ) or isinstance (alignments , dict )):
59+ raise Exception ("Must pass alignments as a list of events, or a dictionary of times" )
60+ elif len (nwb_list ) > 1 and (not isinstance (alignments , list )):
61+ raise Exception (
62+ "Must pass alignments as a list of events, or a list of dictionariesof times"
63+ )
64+
65+ if isinstance (alignments , dict ):
66+ # We have a single NWB, given a dictionary of alignments, make it a list and we are done
67+ align_list = [alignments ]
68+ elif isinstance (alignments , list ):
69+ align_list = []
70+ for i , nwb_i in enumerate (nwb_list ):
71+ align_dict = {}
72+ for a in alignments :
73+ if a not in nwb_i .df_events ["event" ].values :
74+ print ("{} not found in the events table: {}" .format (a , nwb_i .session_id ))
75+ return
76+ else :
77+ align_dict [a ] = nwb_i .df_events .query ("event == @a" )["timestamps" ].values
78+ align_list .append (align_dict )
6479 else :
6580 print (
6681 "alignments must be either a list of events in nwb.df_events, "
@@ -69,21 +84,40 @@ def plot_fip_psth_compare_alignments( # NOQA C901
6984 )
7085 return
7186
72- censor_times = []
73- for key in align_dict :
74- censor_times .append (align_dict [key ])
75- censor_times = np .sort (np .concatenate (censor_times ))
87+ # Compute censor times
88+ censor_times_list = []
89+ for i , nwb_i in enumerate (nwb_list ):
90+ censor_times = []
91+ for key in align_list [i ]:
92+ censor_times .append (align_list [i ][key ])
93+ censor_times = np .sort (np .concatenate (censor_times ))
94+ censor_times_list .append (censor_times )
7695
77- align_label = "Time (s)"
96+ # Create figure if not supplied
7897 if fig is None and ax is None :
7998 fig , ax = plt .subplots ()
8099
100+ # Get colors
81101 colors = {** FIP_COLORS , ** extra_colors }
82102
83- for alignment in align_dict :
84- etr = fip_psth_inner_compute (
85- nwb , align_dict [alignment ], channel , True , tw , censor , censor_times , data_column
86- )
103+ align_label = "Time (s)"
104+ for alignment in align_list [0 ]:
105+ if len (nwb_list ) == 1 :
106+ etr = fip_psth_inner_compute (
107+ nwb_list [0 ],
108+ align_list [0 ][alignment ],
109+ channel ,
110+ True ,
111+ tw ,
112+ censor ,
113+ censor_times_list [0 ],
114+ data_column ,
115+ )
116+ else :
117+ this_align = [x [alignment ] for x in align_list ]
118+ etr = fip_psth_multiple_inner_compute (
119+ nwb_list , this_align , channel , True , tw , censor , censor_times_list , data_column
120+ )
87121 fip_psth_inner_plot (ax , etr , colors .get (alignment , "" ), alignment , data_column )
88122
89123 plt .legend ()
@@ -94,7 +128,10 @@ def plot_fip_psth_compare_alignments( # NOQA C901
94128 ax .set_xlim (tw )
95129 ax .axvline (0 , color = "k" , alpha = 0.2 )
96130 ax .tick_params (axis = "both" , labelsize = STYLE ["axis_ticks_fontsize" ])
97- ax .set_title (nwb .session_id , fontsize = STYLE ["axis_fontsize" ])
131+ if len (nwb_list ) == 1 :
132+ ax .set_title (nwb_list [0 ].session_id , fontsize = STYLE ["axis_fontsize" ])
133+ else :
134+ ax .set_title ("{} sessions" .format (len (nwb_list )), fontsize = STYLE ["axis_fontsize" ])
98135 plt .tight_layout ()
99136 return fig , ax
100137
@@ -176,16 +213,28 @@ def plot_fip_psth_compare_channels(
176213 # Iterate through channels and plot
177214 colors = [FIP_COLORS .get (c , "" ) for c in channels ]
178215 for dex , c in enumerate (channels ):
179- include = [c in nwb .df_fip ["event" ].values for nwb in nwb_list ]
180- etr = fip_psth_multiple_inner_compute (
181- [x for dex , x in enumerate (nwb_list ) if include [dex ]],
182- [x for dex , x in enumerate (align_timepoints_list ) if include [dex ]],
183- c ,
184- True ,
185- tw ,
186- censor ,
187- data_column = data_column ,
188- )
216+ if len (nwb_list ) == 1 :
217+ if c in nwb_list [0 ].df_fip ["event" ].values :
218+ etr = fip_psth_inner_compute (
219+ nwb_list [0 ],
220+ align_timepoints_list [0 ],
221+ c ,
222+ True ,
223+ tw ,
224+ censor ,
225+ data_column = data_column ,
226+ )
227+ else :
228+ include = [c in nwb .df_fip ["event" ].values for nwb in nwb_list ]
229+ etr = fip_psth_multiple_inner_compute (
230+ [x for dex , x in enumerate (nwb_list ) if include [dex ]],
231+ [x for dex , x in enumerate (align_timepoints_list ) if include [dex ]],
232+ c ,
233+ True ,
234+ tw ,
235+ censor ,
236+ data_column = data_column ,
237+ )
189238 fip_psth_inner_plot (ax , etr , colors [dex ], c , data_column )
190239
191240 plt .legend ()
0 commit comments