"""
Tools for computing per session metrics
    compute_auroc: compute auroc for one NWB given alignments
    compute_auroc_multi: compute auroc for multiple NWB given alignments

"""
77
8-
98from sklearn .metrics import roc_auc_score
109from aind_dynamic_foraging_basic_analysis .plot import plot_fip as pf
1110import warnings
1211import pandas as pd
1312import numpy as np
1413
1514
def compute_auroc(nwb, alignment_times, labels, channel, tw, bin_size=0.25, data_col="data_z"):
    """
    Compute the time-resolved area under the ROC curve (auROC) for a single NWB session.

    Parameters
    - nwb: object
        NWB session object expected to contain a DataFrame `df_fip` with
        FIP data and a `session_id`.
    - alignment_times: array-like, shape (n_trials,)
        Times to align trials to (seconds), given in session time
    - labels: array-like, shape (n_trials,)
        Binary labels (0/1) for each alignment time. Must have same
        length as alignment_times.
    - channel: str
        Channel name to select from `nwb.df_fip.event`.
    - tw: tuple (start, end)
        Time window (seconds) around the alignment to compute auROC over
        (centered bins will be between tw[0] and tw[1]).
    - bin_size: float, optional
        Width (seconds) of each time bin used to aggregate values
        before computing auROC. Default 0.25s. Must be positive.
    - data_col: str, optional
        Column name in the FIP data to use for values (default is z-scored data, 'data_z').

    Returns
    - pandas.DataFrame
        DataFrame with columns:
        - 'bin_center': center time of each bin (seconds)
        - 'auc': auROC value for that bin (NaN when computation failed)
        If the requested channel is not present in the NWB,
        returns an empty DataFrame with those columns.

    Raises
    - ValueError
        If labels and alignment_times differ in length, if labels contain
        more than two distinct values, or if bin_size is not positive.

    Notes
    - alignment_times and labels are sorted together before computing PSTHs.
    - Trials with NaNs in the aggregated bin are dropped;
      event_numbers that contain any NaNs across bins are removed.
    """
    # Coerce to arrays so that lists / Series can be reordered positionally below.
    alignment_times = np.asarray(alignment_times)
    labels = np.asarray(labels)

    if len(labels) != len(alignment_times):
        raise ValueError("Alignment times must have same number of labels")

    if np.unique(labels).size > 2:
        raise ValueError("Labels must be binary for auROC computation")

    if bin_size <= 0:
        raise ValueError("bin_size must be positive")

    if channel not in nwb.df_fip.event.unique():
        warnings.warn("No channel found in this NWB, returning empty DataFrame")
        return pd.DataFrame(columns=["bin_center", "auc"])

    # sort labels and alignment times together so they stay paired
    order = np.argsort(alignment_times)
    alignment_times = alignment_times[order]
    labels = labels[order]

    # Widen the window by half a bin on each side so that the first and last
    # bins are centered on tw[0] and tw[1] respectively.
    half_bin = float(bin_size) / 2.0
    tw_for_center_bin = [tw[0] - half_bin, tw[1] + half_bin]

    # get alignments (one row per sample per event, not averaged across events)
    aligns = pf.fip_psth_inner_compute(
        nwb,
        alignment_times,
        channel,
        average=False,
        tw=tw_for_center_bin,
        data_column=data_col,
    )
    n_centers = int(round((tw[1] - tw[0]) / bin_size)) + 1

    # bin the time values into discrete, left-closed bins and compute bin centers
    edges = tw_for_center_bin[0] + np.arange(n_centers + 1) * bin_size
    aligns["time_bin"] = pd.cut(aligns["time"], bins=edges, right=False, include_lowest=True)
    aligns["bin_center"] = aligns["time_bin"].apply(
        lambda iv: (iv.left + half_bin) if pd.notnull(iv) else np.nan
    )

    # samples outside every bin, or with missing data, cannot contribute
    aligns = aligns.dropna(subset=["bin_center", data_col]).copy()

    # average by bin_centers: rows are bin centers, columns are event numbers
    agg_align = (
        aligns.groupby(["bin_center", "event_number"])[data_col].mean().unstack(["event_number"])
    )
    # drop any event_number with nan values for any bin_centers.
    agg_align = agg_align.dropna(how="any", axis=1)

    # NOTE(review): this assumes `event_number` is the 0-based position of the
    # event within the sorted alignment_times passed above - confirm against
    # fip_psth_inner_compute.
    labels_valid = labels[agg_align.columns.values]

    # calculate auROC, one value per bin
    aucs = []
    for _, row in agg_align.iterrows():
        try:
            auc_val = roc_auc_score(labels_valid, row.values)
        except ValueError:
            # e.g. only one class (or no trials) left after dropping events
            auc_val = np.nan
        aucs.append(auc_val)

    return pd.DataFrame(
        {"bin_center": agg_align.index.values, "auc": np.asarray(aucs, dtype=float)}
    )
105-
def compute_auroc_multi(nwb_list, alignment_times_list, label_list, channel, tw, bin_size=0.25):
    """
    Compute auROC across multiple NWB sessions and return a session x time-bin table.

    Parameters
    - nwb_list: list
        NWB session objects, one per session. Each is passed to `compute_auroc`
        and must expose a `session_id`.
    - alignment_times_list: list
        One entry per session; each entry is the alignment times for that session.
    - label_list: list
        One entry per session; each entry is the labels matching that
        session's alignment times.
    - channel: str
        Channel name forwarded to `compute_auroc`.
    - tw: tuple (start, end)
        Time window (seconds) forwarded to `compute_auroc`.
    - bin_size: float, optional
        Bin width (seconds) forwarded to `compute_auroc`. Default 0.25s.

    Returns
    - pandas.DataFrame
        Concatenated DataFrame where each row is a session (index = session_id)
        and each column is a bin_center; cell values are the auROC for that session
        and bin. If no sessions produced results, an empty DataFrame is returned.

    Raises
    - ValueError
        If the three input lists do not all have the same length.
    """

    if len(nwb_list) != len(alignment_times_list) or len(nwb_list) != len(label_list):
        raise ValueError("nwb_list, alignment_times_list, label_list must have the same length")

    # across sessions, should always use z-scored data to compare
    data_col = "data_z"

    auc_df_list = []
    for nwb, align_times, labels in zip(nwb_list, alignment_times_list, label_list):
        auc_df = compute_auroc(nwb, align_times, labels, channel, tw, bin_size, data_col=data_col)
        if auc_df.empty:
            # sessions without the channel (or without any bins) contribute no row
            continue
        auc_df["session_id"] = nwb.session_id
        # pivot to single-row DataFrame: index=session_id, columns=bin_center, values=auc
        row = auc_df.pivot(index="session_id", columns="bin_center", values="auc")
        auc_df_list.append(row)

    if not auc_df_list:
        return pd.DataFrame()

    # Concatenate all single-row DataFrames into one session x bin table
    return pd.concat(auc_df_list, axis=0)
0 commit comments