33"""
44
55import matplotlib .pyplot as plt
6+ import pandas as pd
67import numpy as np
78
89from aind_dynamic_foraging_data_utils import alignment as an
@@ -116,47 +117,76 @@ def plot_fip_psth_compare_channels(
116117 data_column = "data" ,
117118):
118119 """
119- nwb, the nwb object for the session of interest
120+ nwb, the nwb object for the session of interest, or a list of nwb objects
120121 align should either be a string of the name of an event type in nwb.df_events,
121- or a list of timepoints
122+ or a list of timepoints. if nwb is a list, then align should be a list containing
123+ either the string of the name of an event type, or a list of timepoints.
122124 channels should be a list of channel names (strings)
123125 censor, censor important timepoints before and after aligned timepoints
124126 data_column (string), name of data column in nwb.df_fip
125127
126128 EXAMPLE
127129 ********************
128130 plot_fip_psth(nwb, 'goCue_start_time')
131+ plot_fip_psth(nwb_list, 'goCue_start_time')
132+ plot_fip_psth(nwb_list, ['goCue_start_time','goCue_start_time'])
129133 """
130- if not hasattr (nwb , "df_fip" ):
131- print ("You need to compute the df_fip first" )
132- print ("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`" )
133- nwb .df_fip = nu .create_fib_df (nwb , tidy = True )
134- if not hasattr (nwb , "df_events" ):
135- print ("You need to compute the df_events first" )
136- print ("run `nwb.df_events = create_events_df(nwb)`" )
137- nwb .df_events = nu .create_events_df (nwb )
134+ # Check if nwb is a list, otherwise put it in a list to check
135+ nwb_list = nwb if isinstance (nwb , list ) else [nwb ]
138136
139- if isinstance (align , str ):
140- if align not in nwb .df_events ["event" ].values :
141- print ("{} not found in the events table" .format (align ))
142- return
143- align_timepoints = nwb .df_events .query ("event == @align" )["timestamps" ].values
144- align_label = "Time from {} (s)" .format (align )
145- else :
146- align_timepoints = align
147- align_label = "Time (s)"
137+ if isinstance (nwb , list ) and isinstance (align , list ) and (len (nwb ) != len (align )):
138+ raise Exception ("NWB list and align list must match" )
139+ if isinstance (nwb , list ) and isinstance (align , str ):
140+ align = [align ] * len (nwb )
141+ if not isinstance (nwb , list ):
142+ align = [align ]
148143
144+ # First check that each session has an events table and fip table
145+ for nwb_i in nwb_list :
146+ if not hasattr (nwb_i , "df_fip" ):
147+ print ("You need to compute the df_fip first" )
148+ print ("running `nwb.df_fip = create_fib_df(nwb,tidy=True)`" )
149+ nwb_i .df_fip = nu .create_fib_df (nwb_i , tidy = True )
150+ if not hasattr (nwb_i , "df_events" ):
151+ print ("You need to compute the df_events first" )
152+ print ("run `nwb.df_events = create_events_df(nwb)`" )
153+ nwb_i .df_events = nu .create_events_df (nwb_i )
154+
155+ align_timepoints_list = []
156+ # Generate the alignment timepoints for each session
157+ for i , nwb_i in enumerate (nwb_list ):
158+ align_i = align [i ]
159+ if isinstance (align_i , str ):
160+ if align_i not in nwb_i .df_events ["event" ].values :
161+ print ("{} not found in the events table, {}" .format (align_i , nwb_i .session_id ))
162+ return
163+
164+ align_timepoints_list .append (
165+ nwb_i .df_events .query ("event == @align" )["timestamps" ].values
166+ )
167+ align_label = "Time from {} (s)" .format (align_i )
168+ else :
169+ align_timepoints_list .append (align_i )
170+ align_label = "Time (s)"
171+
172+ # Make figure if not supplied
149173 if fig is None and ax is None :
150174 fig , ax = plt .subplots ()
151175
176+ # Iterate through channels and plot
152177 colors = [FIP_COLORS .get (c , "" ) for c in channels ]
153178 for dex , c in enumerate (channels ):
154- if c in nwb .df_fip ["event" ].values :
155- etr = fip_psth_inner_compute (nwb , align_timepoints , c , True , tw ,
156- censor , data_column = data_column )
157- fip_psth_inner_plot (ax , etr , colors [dex ], c , data_column )
158- else :
159- print ("No data for channel: {}" .format (c ))
179+ include = [c in nwb .df_fip ["event" ].values for nwb in nwb_list ]
180+ etr = fip_psth_multiple_inner_compute (
181+ [x for dex , x in enumerate (nwb_list ) if include [dex ]],
182+ [x for dex , x in enumerate (align_timepoints_list ) if include [dex ]],
183+ c ,
184+ True ,
185+ tw ,
186+ censor ,
187+ data_column = data_column ,
188+ )
189+ fip_psth_inner_plot (ax , etr , colors [dex ], c , data_column )
160190
161191 plt .legend ()
162192 ax .set_xlabel (align_label , fontsize = STYLE ["axis_fontsize" ])
@@ -166,7 +196,10 @@ def plot_fip_psth_compare_channels(
166196 ax .set_xlim (tw )
167197 ax .axvline (0 , color = "k" , alpha = 0.2 )
168198 ax .tick_params (axis = "both" , labelsize = STYLE ["axis_ticks_fontsize" ])
169- ax .set_title (nwb .session_id )
199+ if len (nwb_list ) == 1 :
200+ ax .set_title (nwb_list [0 ].session_id )
201+ else :
202+ ax .set_title ("{} sessions" .format (len (nwb_list )))
170203 plt .tight_layout ()
171204 return fig , ax
172205
@@ -183,11 +216,69 @@ def fip_psth_inner_plot(ax, etr, color, label, data_column):
183216 if color == "" :
184217 cmap = plt .get_cmap ("tab20" )
185218 color = cmap (np .random .randint (20 ))
186- ax .fill_between (etr .index , etr [data_column ] - etr ["sem" ],
187- etr [data_column ] + etr ["sem" ], color = color , alpha = 0.2 )
219+ ax .fill_between (
220+ etr .index ,
221+ etr [data_column ] - etr ["sem" ],
222+ etr [data_column ] + etr ["sem" ],
223+ color = color ,
224+ alpha = 0.2 ,
225+ )
188226 ax .plot (etr .index , etr [data_column ], color = color , label = label )
189227
190228
229+ def fip_psth_multiple_inner_compute (
230+ nwb_list ,
231+ align_timepoints_list ,
232+ channel ,
233+ average ,
234+ tw = [- 1 , 1 ],
235+ censor = True ,
236+ censor_times = None ,
237+ data_column = "data" ,
238+ ):
239+ """ """
240+ # Check that len(nwb_list) = len(align_timepoints_list) = len(censor_times)
241+ if len (nwb_list ) != len (align_timepoints_list ):
242+ raise Exception ("length of nwb list and alignments list must match" )
243+ if censor and censor_times is None :
244+ censor_times = [None ] * len (nwb_list )
245+ if censor and (len (nwb_list ) != len (censor_times )):
246+ raise Exception ("length of nwb list and censor times must match" )
247+
248+ etr_list = []
249+ # Iterate through list of sessions, computing the etr for each
250+ for i , nwb_i in enumerate (nwb_list ):
251+ etr_i = fip_psth_inner_compute (
252+ nwb_i ,
253+ align_timepoints_list [i ],
254+ channel ,
255+ average = False ,
256+ tw = tw ,
257+ censor = censor ,
258+ censor_times = censor_times [i ],
259+ data_column = data_column ,
260+ )
261+ etr_i ["ses_idx" ] = nwb_i .session_id
262+ etr_list .append (etr_i )
263+
264+ # Concat etrs from each session into one dataframe
265+ etr_all = pd .concat (etr_list , axis = 0 ).reset_index (drop = True )
266+
267+ if average :
268+ # Average within each ses_idx for each time point
269+ mean_per_ses = etr_all .groupby (["ses_idx" , "time" ])[data_column ].mean ().unstack ("ses_idx" )
270+ # Grand mean: average across ses_idx for each time point
271+ grand_mean = mean_per_ses .mean (axis = 1 )
272+ # SEM over ses_idx for each time point
273+ grand_sem = mean_per_ses .sem (axis = 1 )
274+ # Combine into a DataFrame
275+ result = grand_mean .to_frame (name = data_column )
276+ result ["sem" ] = grand_sem
277+ return result
278+ else :
279+ return etr_all
280+
281+
191282def fip_psth_inner_compute (
192283 nwb ,
193284 align_timepoints ,
0 commit comments