33"""
44
55import matplotlib .pyplot as plt
6- import pandas as pd
76import numpy as np
8-
7+ import pandas as pd
98from aind_dynamic_foraging_data_utils import alignment as an
109from aind_dynamic_foraging_data_utils import nwb_utils as nu
11- from aind_dynamic_foraging_basic_analysis .plot .style import STYLE , FIP_COLORS
10+
11+ from aind_dynamic_foraging_basic_analysis .plot .style import FIP_COLORS , STYLE
1212
1313
1414def plot_fip_psth_compare_alignments ( # NOQA C901
@@ -21,6 +21,7 @@ def plot_fip_psth_compare_alignments( # NOQA C901
2121 censor = True ,
2222 extra_colors = {},
2323 data_column = "data" ,
24+ error_type = "sem" ,
2425):
2526 """
2627 Compare the same FIP channel aligned to multiple event types
@@ -39,6 +40,8 @@ def plot_fip_psth_compare_alignments( # NOQA C901
3940 plot_fip_psth_compare_alignments(nwb,['left_reward_delivery_time',
4041 'right_reward_delivery_time'],'G_1_preprocessed')
4142 """
43+ if error_type not in ["sem" , "sem_over_sessions" ]:
44+ raise Exception ("unknown error type" )
4245
4346 nwb_list = nwb if isinstance (nwb , list ) else [nwb ]
4447 for nwb_i in nwb_list :
@@ -102,23 +105,11 @@ def plot_fip_psth_compare_alignments( # NOQA C901
102105
103106 align_label = "Time (s)"
104107 for alignment in align_list [0 ]:
105- if len (nwb_list ) == 1 :
106- etr = fip_psth_inner_compute (
107- nwb_list [0 ],
108- align_list [0 ][alignment ],
109- channel ,
110- True ,
111- tw ,
112- censor ,
113- censor_times_list [0 ],
114- data_column ,
115- )
116- else :
117- this_align = [x [alignment ] for x in align_list ]
118- etr = fip_psth_multiple_inner_compute (
119- nwb_list , this_align , channel , True , tw , censor , censor_times_list , data_column
120- )
121- fip_psth_inner_plot (ax , etr , colors .get (alignment , "" ), alignment , data_column )
108+ this_align = [x [alignment ] for x in align_list ]
109+ etr = fip_psth_multiple_inner_compute (
110+ nwb_list , this_align , channel , True , tw , censor , censor_times_list , data_column
111+ )
112+ fip_psth_inner_plot (ax , etr , colors .get (alignment , "" ), alignment , data_column , error_type )
122113
123114 plt .legend ()
124115 ax .set_xlabel (align_label , fontsize = STYLE ["axis_fontsize" ])
@@ -152,6 +143,7 @@ def plot_fip_psth_compare_channels(
152143 ],
153144 censor = True ,
154145 data_column = "data" ,
146+ error_type = "sem" ,
155147):
156148 """
157149 nwb, the nwb object for the session of interest, or a list of nwb objects
@@ -168,6 +160,10 @@ def plot_fip_psth_compare_channels(
168160 plot_fip_psth(nwb_list, 'goCue_start_time')
169161 plot_fip_psth(nwb_list, ['goCue_start_time','goCue_start_time'])
170162 """
163+
164+ if error_type not in ["sem" , "sem_over_sessions" ]:
165+ raise Exception ("Unknown error type" )
166+
171167 # Check if nwb is a list, otherwise put it in a list to check
172168 nwb_list = nwb if isinstance (nwb , list ) else [nwb ]
173169
@@ -213,29 +209,17 @@ def plot_fip_psth_compare_channels(
213209 # Iterate through channels and plot
214210 colors = [FIP_COLORS .get (c , "" ) for c in channels ]
215211 for dex , c in enumerate (channels ):
216- if len (nwb_list ) == 1 :
217- if c in nwb_list [0 ].df_fip ["event" ].values :
218- etr = fip_psth_inner_compute (
219- nwb_list [0 ],
220- align_timepoints_list [0 ],
221- c ,
222- True ,
223- tw ,
224- censor ,
225- data_column = data_column ,
226- )
227- else :
228- include = [c in nwb .df_fip ["event" ].values for nwb in nwb_list ]
229- etr = fip_psth_multiple_inner_compute (
230- [x for dex , x in enumerate (nwb_list ) if include [dex ]],
231- [x for dex , x in enumerate (align_timepoints_list ) if include [dex ]],
232- c ,
233- True ,
234- tw ,
235- censor ,
236- data_column = data_column ,
237- )
238- fip_psth_inner_plot (ax , etr , colors [dex ], c , data_column )
212+ include = [c in nwb .df_fip ["event" ].values for nwb in nwb_list ]
213+ etr = fip_psth_multiple_inner_compute (
214+ [x for dex , x in enumerate (nwb_list ) if include [dex ]],
215+ [x for dex , x in enumerate (align_timepoints_list ) if include [dex ]],
216+ c ,
217+ True ,
218+ tw ,
219+ censor ,
220+ data_column = data_column ,
221+ )
222+ fip_psth_inner_plot (ax , etr , colors [dex ], c , data_column , error_type )
239223
240224 plt .legend ()
241225 ax .set_xlabel (align_label , fontsize = STYLE ["axis_fontsize" ])
@@ -253,22 +237,23 @@ def plot_fip_psth_compare_channels(
253237 return fig , ax
254238
255239
256- def fip_psth_inner_plot (ax , etr , color , label , data_column ):
240+ def fip_psth_inner_plot (ax , etr , color , label , data_column , error_type = "sem" ):
257241 """
258242 helper function that plots an event triggered response
259243 ax, the pyplot axis to plot on
260244 etr, the dataframe that contains the event triggered response
261245 color, the line color to plot
262246 label, the label for the etr
263247 data_column (string), name of data_column
248+ error_type, the error bar type to plot, must be a column in etr
264249 """
265250 if color == "" :
266251 cmap = plt .get_cmap ("tab20" )
267252 color = cmap (np .random .randint (20 ))
268253 ax .fill_between (
269254 etr .index ,
270- etr [data_column ] - etr ["sem" ],
271- etr [data_column ] + etr ["sem" ],
255+ etr [data_column ] - etr [error_type ],
256+ etr [data_column ] + etr [error_type ],
272257 color = color ,
273258 alpha = 0.2 ,
274259 )
@@ -322,7 +307,11 @@ def fip_psth_multiple_inner_compute(
322307 grand_sem = mean_per_ses .sem (axis = 1 )
323308 # Combine into a DataFrame
324309 result = grand_mean .to_frame (name = data_column )
325- result ["sem" ] = grand_sem
310+ result ["sem_over_sessions" ] = grand_sem
311+
312+ # Compute SEM collapsing over sessions
313+ result ["sem" ] = etr_all .groupby ("time" )[data_column ].sem ()
314+
326315 return result
327316 else :
328317 return etr_all
0 commit comments