1- import warnings
2- import glob
3- import pandas as pd
41import numpy as np
2+ import pandas as pd
3+ from analysis_wrapper .plots import summary_plots
4+ from aind_dynamic_foraging_basic_analysis .metrics import trial_metrics
55
66
class dummy_nwb:
    """Minimal NWB-like container holding per-session DataFrames.

    Attributes set by ``__init__``:
      * ``session_id``   - the selected ``ses_idx`` (or a ``', '``-joined list
        of all ``ses_idx`` values in ``df_trials`` when ``grouped`` is True).
      * ``df_trials`` / ``df_events`` / ``df_fip`` - the input frames, filtered
        to the selected session unless ``grouped`` is True.
      * ``df_licks``     - only set when ``df_licks`` is given and ``grouped``
        is not True.
      * ``nwb_file_loc`` - first glob match for the session's ``.nwb`` file
        under ``/root/capsule/data``, or ``None``; not set when ``grouped``
        is True.
    """

    def __init__(self, df_trials, df_events, df_fip, ses_idx=None, df_licks=None, grouped=False) -> None:
        """Attach the rows of one session (or all rows when grouped) to this object.

        Parameters
        ----------
        df_trials, df_events, df_fip : pd.DataFrame
            Frames that each carry a ``ses_idx`` column.
        ses_idx : optional
            Session to select. When omitted, the first unique ``ses_idx`` of
            ``df_trials`` is used and a warning is issued if any of the three
            frames holds more than one session.
        df_licks : pd.DataFrame, optional
            Lick frame with a ``ses_idx`` column; filtered like the others.
        grouped : bool
            When exactly ``True``, the frames are stored unfiltered.

        Raises
        ------
        AssertionError
            If ``df_fip`` has no rows for the selected session.
        """
        if grouped is True:
            # Grouped mode: keep every row and label the object with all sessions.
            self.df_events = df_events
            self.df_fip = df_fip
            self.df_trials = df_trials
            self.session_id = ', '.join(df_trials.ses_idx.unique())
            return

        if ses_idx is None and grouped is False:
            has_multiple_sessions = (
                len(df_trials.ses_idx.unique()) > 1
                or len(df_events.ses_idx.unique()) > 1
                or len(df_fip.ses_idx.unique()) > 1
            )
            if has_multiple_sessions:
                warnings.warn('multiple sessions found, only one will be attached to this nwb')
            ses_idx = df_trials.ses_idx.unique()[0]

        fip_rows = df_fip['ses_idx'] == ses_idx
        if not fip_rows.any():
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped when Python runs with -O; the exception type is unchanged.
            raise AssertionError("No session exists in the df_fip")

        self.session_id = ses_idx
        self.df_events = df_events[df_events['ses_idx'] == ses_idx]
        self.df_fip = df_fip[fip_rows].copy().reset_index(drop=True)
        self.df_trials = df_trials[df_trials['ses_idx'] == ses_idx]
        # A DataFrame has no unambiguous truth value, so test against None.
        if df_licks is not None:
            self.df_licks = df_licks[df_licks['ses_idx'] == ses_idx]

        nwb_file_name = glob.glob(f"/root/capsule/data/**{ses_idx}**/nwb/**.nwb")
        self.nwb_file_loc = nwb_file_name[0] if nwb_file_name else None

    def __str__(self):
        return f"session {self.session_id}"

    def __repr__(self):
        return f"{self.session_id}"
def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column='data_z_norm',
                          alignment_event='choice_time_in_session', offsets=None):
    """Enrich each week's sessions with a windowed average signal and fit RPE slopes.

    For every channel key in ``parameters["channels"]`` (suffixed with
    ``'_' + parameters['preprocessing']`` unless preprocessing equals
    ``'raw'``), each entry of ``nwbs_by_week`` is passed through
    ``trial_metrics.get_average_signal_window_multi`` and the result is
    appended to the returned list. Then, per session, the slope returned by
    ``summary_plots.get_RPE_by_avg_signal_fit`` is collected separately for
    trials with ``RPE_earned >= 0`` and ``RPE_earned < 0``.

    Parameters
    ----------
    nwbs_by_week : iterable
        Iterable of per-week collections of nwb-like objects, each exposing
        ``df_trials`` with ``ses_idx`` and ``RPE_earned`` columns.
    parameters : mapping
        Must provide ``"channels"`` (a mapping whose keys are channel names)
        and ``"preprocessing"``.
    data_column, alignment_event : str
        Forwarded to the helper functions.
    offsets : list, optional
        Window offsets forwarded to the helper; defaults to ``[0.33, 1]``.

    Returns
    -------
    (list, pd.DataFrame)
        The enriched per-week collections and a frame with columns
        ``channel``, ``date``, ``slope (RPE >= 0)``, ``slope (RPE < 0)``.

    Side effects
    ------------
    Writes the combined frame to ``/results/{subject_id}_rpe_slope.csv``.
    """
    if offsets is None:
        # Fresh list per call instead of a shared mutable default.
        offsets = [0.33, 1]

    rpe_slope_dict = {}
    nwbs_by_week_enriched = []
    for channel in parameters["channels"].keys():
        # Value comparison: identity (`is not`) against a string literal is
        # unreliable for strings that are not interned (e.g. parsed from JSON).
        if parameters['preprocessing'] != 'raw':
            channel = channel + '_' + parameters['preprocessing']

        avg_signal_col = summary_plots.output_col_name(channel, data_column, alignment_event)
        for nwb_week in nwbs_by_week:
            nwb_week_enriched = trial_metrics.get_average_signal_window_multi(
                nwb_week,
                alignment_event=alignment_event,
                offsets=offsets,
                channel=channel,
                data_column=data_column,
                output_col=avg_signal_col,
            )
            # NOTE(review): with more than one channel this list receives one
            # entry per (channel, week), and the concat below spans all of
            # them - confirm that is intended.
            nwbs_by_week_enriched.append(nwb_week_enriched)

        # Get the RPE slope per session.
        df_trials_all = pd.concat(
            [nwb.df_trials for nwb_week in nwbs_by_week_enriched for nwb in nwb_week]
        )
        rpe_slope = []
        for ses_idx in sorted(df_trials_all['ses_idx'].unique()):
            data = df_trials_all[df_trials_all['ses_idx'] == ses_idx]
            data = data.dropna(subset=[avg_signal_col, 'RPE_earned'])
            if len(data) == 0:
                continue
            data_neg = data[data['RPE_earned'] < 0]
            data_pos = data[data['RPE_earned'] >= 0]

            # The session date is the second '_'-separated field of ses_idx.
            ses_date = pd.to_datetime(ses_idx.split('_')[1])
            (_, _, slope_pos) = summary_plots.get_RPE_by_avg_signal_fit(data_pos, avg_signal_col)
            (_, _, slope_neg) = summary_plots.get_RPE_by_avg_signal_fit(data_neg, avg_signal_col)
            rpe_slope.append([ses_date, slope_pos, slope_neg])
        rpe_slope = pd.DataFrame(rpe_slope, columns=['date', 'slope (RPE >= 0)', 'slope (RPE < 0)'])
        rpe_slope_dict[channel] = rpe_slope

    # The subject id is parsed from str() of the first enriched nwb: second
    # space-separated token, text before the first underscore.
    subject_id = str(nwbs_by_week_enriched[0][0]).split(' ')[1].split('_')[0]
    # Concatenate with keys, turning the dict keys into a "channel" column.
    combined_rpe_slope = pd.concat(rpe_slope_dict, names=["channel"])
    combined_rpe_slope = combined_rpe_slope.reset_index(level="channel").reset_index(drop=True)

    combined_rpe_slope.to_csv(f"/results/{subject_id}_rpe_slope.csv")

    return nwbs_by_week_enriched, combined_rpe_slope
56+
57+
def enrich_df_trials(df_trials):
    """Add reward-history, binned RPE / Q_chosen and stay/switch columns in place.

    Columns read: ``ses_idx``, ``earned_reward``, ``extra_reward``,
    ``RPE_earned``, ``Q_chosen``, ``choice``, ``choice_time_in_trial``,
    ``goCue_start_time_in_trial``.

    Columns written:
      * ``reward_all``      - ``earned_reward + extra_reward``.
      * ``rewarded_prev``   - ``reward_all`` of the previous row in the same session.
      * ``num_reward_past`` - 1-based position inside the current run of equal
        ``reward_all`` values, negated where ``reward_all == 0``.
      * ``RPE-binned3``     - ``RPE_earned`` cut into six right-closed bins of
        width 1/3, labelled by each bin's left edge.
      * ``Qch-binned3``     - ordered categorical combining ``R+``/``R-`` with
        the ``Q_chosen`` tercile.
      * ``stay`` / ``switch`` - whether ``choice`` equals / differs from the
        previous choice in the same session.
      * ``response_time``   - ``choice_time_in_trial - goCue_start_time_in_trial``.

    The input frame is mutated and returned.
    NOTE(review): the run-length logic follows row order, so rows are
    presumably sorted by trial within each session - confirm with callers.
    """
    # ---- PART I: REWARD ----
    df_trials['reward_all'] = df_trials['earned_reward'] + df_trials['extra_reward']
    # Previous outcome within the session (NaN on each session's first row).
    df_trials['rewarded_prev'] = df_trials.groupby('ses_idx')['reward_all'].shift(1)

    # A new run starts whenever the outcome differs from the previous row.
    run_id = (df_trials['rewarded_prev'] != df_trials['reward_all']).cumsum()
    df_trials['num_reward_past'] = df_trials.groupby(run_id).cumcount() + 1

    # Unrewarded runs are reported as negative counts.
    unrewarded = df_trials['reward_all'] == 0
    df_trials.loc[unrewarded, 'num_reward_past'] = df_trials.loc[unrewarded, 'num_reward_past'] * -1

    # ---- PART II: BINNING RPE ----
    RPE_binned3_label_names = [str(np.round(i, 2)) for i in np.arange(-1, 0.99, 1 / 3)]

    rpe_bins = np.arange(-1, 1.01, 1 / 3)
    rpe_bins[-1] = 1.001

    # include_lowest=True keeps RPE_earned == -1 in the first bin instead of
    # turning it into NaN, matching how Q_chosen is binned below.
    df_trials['RPE-binned3'] = pd.cut(
        df_trials['RPE_earned'],
        bins=rpe_bins,
        right=True,
        include_lowest=True,
        labels=RPE_binned3_label_names,
    )

    # ---- PART III: BINNING QCHOSEN ----
    q_bins = [0.0, 1 / 3, 2 / 3, 1.01]
    q_labels = ["Qch 0", "Qch 0.33", "Qch 0.66"]

    q_bin = pd.cut(df_trials['Q_chosen'], bins=q_bins, labels=q_labels, include_lowest=True, right=True)
    reward_label = df_trials['earned_reward'].map({True: "R+", False: "R-"})

    # Combined label, None wherever Q_chosen fell outside the bins.
    reward_Qcat_series = pd.Series(
        np.where(q_bin.isna(), None, reward_label.astype(str) + " (" + q_bin.astype(str) + ")"),
        index=df_trials.index,
    )

    Qch_binned3_label_names = [
        "R- (Qch 0)", "R- (Qch 0.33)", "R- (Qch 0.66)",
        "R+ (Qch 0)", "R+ (Qch 0.33)", "R+ (Qch 0.66)",
    ]

    df_trials['Qch-binned3'] = pd.Categorical(
        reward_Qcat_series, categories=Qch_binned3_label_names, ordered=True
    )

    # ---- PART IV: STAY / SWITCH ----
    _choice_shifted = df_trials.groupby('ses_idx')['choice'].shift(1)
    df_trials['stay'] = df_trials['choice'] == _choice_shifted
    # A switch needs a previous choice: without the notna() guard the first
    # trial of every session would count as a switch (NaN != x is True).
    df_trials['switch'] = _choice_shifted.notna() & (df_trials['choice'] != _choice_shifted)
    df_trials['response_time'] = df_trials['choice_time_in_trial'] - df_trials['goCue_start_time_in_trial']

    return df_trials


def get_dummy_nwbs(df_trials, df_events, df_fip):
    """Build one ``dummy_nwb`` per session of ``df_trials``, ordered by session date.

    The date is the second ``'_'``-separated field of ``ses_idx``. Sessions
    missing from ``df_events`` or ``df_fip`` are skipped with a ``UserWarning``.
    """
    ses_idx_list = df_trials.ses_idx.unique()
    dummy_nwbs_list = []
    ses_dates_order = np.argsort(pd.to_datetime([ses_idx.split('_')[1] for ses_idx in ses_idx_list]))

    for ses_idx in ses_idx_list[ses_dates_order]:
        present_everywhere = (
            ses_idx in df_events['ses_idx'].values
            and ses_idx in df_fip['ses_idx'].values
            and ses_idx in df_trials['ses_idx'].values
        )
        if not present_everywhere:
            warnings.warn(f"Skipping {ses_idx}: not found in all input DataFrames.", UserWarning)
            continue

        df_trials_i = df_trials[df_trials['ses_idx'] == ses_idx]
        df_events_i = df_events[df_events['ses_idx'] == ses_idx]
        df_fip_i = df_fip[df_fip['ses_idx'] == ses_idx]
        dummy_nwbs_list.append(dummy_nwb(df_trials_i, df_events_i, df_fip_i))

    return dummy_nwbs_list


def get_dummy_nwbs_by_subject(df_trials, df_events, df_fip):
    """Return a list holding, per subject, the list produced by ``get_dummy_nwbs``.

    The subject id is the text before the first ``'_'`` of ``ses_idx``; it is
    written to a ``subject_id`` column on all three input frames (in place).
    Subjects missing from ``df_events`` or ``df_fip`` are skipped with a
    ``UserWarning``.
    """
    df_trials['subject_id'] = df_trials['ses_idx'].str.split('_').str[0]
    df_events['subject_id'] = df_events['ses_idx'].str.split('_').str[0]
    df_fip['subject_id'] = df_fip['ses_idx'].str.split('_').str[0]

    dummy_nwbs_list = []
    for subject_id in df_trials.subject_id.unique():
        present_everywhere = (
            subject_id in df_events['subject_id'].values
            and subject_id in df_fip['subject_id'].values
            and subject_id in df_trials['subject_id'].values
        )
        if not present_everywhere:
            warnings.warn(f"Skipping {subject_id}: not found in all input DataFrames.", UserWarning)
            continue

        df_trials_i = df_trials[df_trials['subject_id'] == subject_id]
        df_events_i = df_events[df_events['subject_id'] == subject_id]
        df_fip_i = df_fip[df_fip['subject_id'] == subject_id]
        dummy_nwbs_list.append(get_dummy_nwbs(df_trials_i, df_events_i, df_fip_i))

    return dummy_nwbs_list


def get_date_and_week_interval(df, start_date):
    """Return the 1-based week number of each row relative to ``start_date``.

    The date is parsed (``%Y-%m-%d``) from the second ``'_'``-separated field
    of ``df['ses_idx']``; days 0-6 after ``start_date`` map to week 1, and so on.
    """
    date_series = pd.to_datetime(df['ses_idx'].str.split('_').str[1], format='%Y-%m-%d')
    week_interval_series = ((date_series - start_date).dt.days // 7) + 1
    return week_interval_series


def get_dummy_nwbs_by_week(df_sess, df_trials, df_events, df_fip):
    """Return ``df_sess`` and, per week, the list produced by ``get_dummy_nwbs``.

    Weeks are counted from the earliest ``df_sess['session_date']``. A
    ``week_interval`` column is written to all four input frames (in place).
    Weeks missing from ``df_events`` or ``df_fip`` are skipped with a
    ``UserWarning``.
    """
    start_date = pd.to_datetime(df_sess['session_date'].min())

    df_sess['week_interval'] = get_date_and_week_interval(df_sess, start_date)
    df_trials['week_interval'] = get_date_and_week_interval(df_trials, start_date)
    df_events['week_interval'] = get_date_and_week_interval(df_events, start_date)
    df_fip['week_interval'] = get_date_and_week_interval(df_fip, start_date)

    dummy_nwbs_list = []
    for week_interval in df_trials.week_interval.unique():
        present_everywhere = (
            week_interval in df_events['week_interval'].values
            and week_interval in df_fip['week_interval'].values
            and week_interval in df_trials['week_interval'].values
        )
        if not present_everywhere:
            warnings.warn(f"Skipping {week_interval}: not found in all input DataFrames.", UserWarning)
            continue

        df_trials_i = df_trials[df_trials['week_interval'] == week_interval]
        df_events_i = df_events[df_events['week_interval'] == week_interval]
        df_fip_i = df_fip[df_fip['week_interval'] == week_interval]
        dummy_nwbs_list.append(get_dummy_nwbs(df_trials_i, df_events_i, df_fip_i))

    return df_sess, dummy_nwbs_list


def combine_dummy_nwbs_to_dfs(dummy_nwbs_list):
    """
    Given a list of dummy_nwb objects, concatenate their df_trials, df_events, and df_fip
    into three large DataFrames (with a fresh 0..n-1 index each).

    Parameters
    ----------
    dummy_nwbs_list : list of dummy_nwb

    Returns
    -------
    tuple of pd.DataFrame
        (df_trials_all, df_events_all, df_fip_all)
    """
    df_trials_all = pd.concat([nwb.df_trials for nwb in dummy_nwbs_list], ignore_index=True)
    df_events_all = pd.concat([nwb.df_events for nwb in dummy_nwbs_list], ignore_index=True)
    df_fip_all = pd.concat([nwb.df_fip for nwb in dummy_nwbs_list], ignore_index=True)

    return df_trials_all, df_events_all, df_fip_all
0 commit comments