Skip to content

Commit b832718

Browse files
added analysis_util enriching functions.
1 parent ec2d08c commit b832718

2 files changed

Lines changed: 260 additions & 148 deletions

File tree

Lines changed: 103 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -1,156 +1,111 @@
1-
import warnings
2-
import glob
3-
import pandas as pd
41
import numpy as np
2+
import pandas as pd
3+
from analysis_wrapper.plots import summary_plots
4+
from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics
55

66

7-
class dummy_nwb:
    """
    Lightweight stand-in for an NWB session built from pre-loaded DataFrames.

    Parameters
    ----------
    df_trials, df_events, df_fip : pd.DataFrame
        Tables that each carry a ``ses_idx`` column.
    ses_idx : str, optional
        Session to attach. When omitted (and ``grouped`` is False) the first
        value of ``df_trials.ses_idx.unique()`` is used.
    df_licks : pd.DataFrame, optional
        Table with a ``ses_idx`` column; filtered to the session when given.
    grouped : bool, optional
        When True the three tables are stored unfiltered, ``session_id`` is the
        comma-joined list of unique ``ses_idx`` values in ``df_trials``, and
        nothing else is set (no ``df_licks``, no ``nwb_file_loc``).

    Attributes
    ----------
    session_id, df_trials, df_events, df_fip
        Always set.
    df_licks
        Only set in non-grouped mode when ``df_licks`` was supplied.
    nwb_file_loc
        Only set in non-grouped mode: first path matching
        ``/root/capsule/data/**{ses_idx}**/nwb/**.nwb``, or None if no match.

    Raises
    ------
    AssertionError
        In non-grouped mode, when ``df_fip`` has no rows for the session.
    """

    def __init__(self, df_trials, df_events, df_fip, ses_idx=None, df_licks=None, grouped=False) -> None:
        if grouped is True:
            # Grouped mode keeps every session together and returns early.
            self.df_events = df_events
            self.df_fip = df_fip
            self.df_trials = df_trials
            self.session_id = ', '.join(df_trials.ses_idx.unique())
            return

        if ses_idx is None and grouped is False:
            has_multiple_sessions = (
                len(df_trials.ses_idx.unique()) > 1
                or len(df_events.ses_idx.unique()) > 1
                or len(df_fip.ses_idx.unique()) > 1
            )
            if has_multiple_sessions:
                warnings.warn('multiple sessions found, only one will be attached to this nwb')
            # Fall back to the first session listed in the trials table.
            ses_idx = df_trials.ses_idx.unique()[0]

        assert df_fip[df_fip['ses_idx'] == ses_idx].shape[0] != 0, (
            "No session exists in the df_fip"
        )
        self.session_id = ses_idx
        self.df_events = df_events[df_events['ses_idx'] == ses_idx]
        self.df_fip = df_fip[df_fip['ses_idx'] == ses_idx].copy().reset_index(drop=True)
        self.df_trials = df_trials[df_trials['ses_idx'] == ses_idx]
        # A DataFrame has no unambiguous truth value, so test against None
        # explicitly; ``if df_licks:`` raises ValueError for any DataFrame.
        if df_licks is not None:
            self.df_licks = df_licks[df_licks['ses_idx'] == ses_idx]

        nwb_file_name = glob.glob(f"/root/capsule/data/**{ses_idx}**/nwb/**.nwb")
        self.nwb_file_loc = nwb_file_name[0] if nwb_file_name else None

    def __str__(self):
        return f"session {self.session_id}"

    def __repr__(self):
        return f"{self.session_id}"
7+
def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column='data_z_norm',
                          alignment_event='choice_time_in_session', offsets=None):
    """
    Add a windowed average-signal column per channel and fit per-session RPE slopes.

    For every channel in ``parameters["channels"]`` each week is passed through
    ``trial_metrics.get_average_signal_window_multi``; then, per session,
    ``summary_plots.get_RPE_by_avg_signal_fit`` is called separately on the
    trials with ``RPE_earned >= 0`` and with ``RPE_earned < 0``. The combined
    slope table is also written to ``/results/{subject_id}_rpe_slope.csv``.

    Parameters
    ----------
    nwbs_by_week : list
        One iterable of session objects per week. Each session object must
        expose ``df_trials`` with ``ses_idx`` and ``RPE_earned`` columns. The
        subject id is parsed from ``str()`` of the first session, which is
        therefore expected to look like ``"<word> <subject>_<rest>"``.
    parameters : dict
        Must provide ``parameters["channels"]`` (its keys are the channel
        names) and ``parameters["preprocessing"]``. Unless the latter equals
        ``'raw'`` it is appended to each channel name as ``"<channel>_<preprocessing>"``.
    data_column : str, optional
        Forwarded to ``get_average_signal_window_multi`` and ``output_col_name``.
    alignment_event : str, optional
        Forwarded to ``get_average_signal_window_multi`` and ``output_col_name``.
    offsets : list, optional
        Forwarded to ``get_average_signal_window_multi``; defaults to ``[0.33, 1]``.

    Returns
    -------
    nwbs_by_week_enriched : list
        One enriched entry per input week, carrying the columns of every channel.
    combined_rpe_slope : pd.DataFrame
        Columns ``channel``, ``date``, ``slope (RPE >= 0)``, ``slope (RPE < 0)``.
    """
    if offsets is None:
        # Resolved per call so the default list is never shared between calls.
        offsets = [0.33, 1]

    rpe_slope_dict = {}
    # Enrichment is cumulative: every channel is applied to the output of the
    # previous one, so the result holds exactly one entry per week. Appending a
    # fresh copy of every week per channel would repeat weeks (and therefore
    # trials) as soon as more than one channel is configured.
    # NOTE(review): whether get_average_signal_window_multi mutates its input
    # or returns new objects is not visible here; chaining works in both cases.
    nwbs_by_week_enriched = list(nwbs_by_week)
    for channel in parameters["channels"]:
        # Compare by value: an identity test against a string literal depends
        # on interning and is not a reliable equality check.
        if parameters['preprocessing'] != 'raw':
            channel = channel + '_' + parameters['preprocessing']

        avg_signal_col = summary_plots.output_col_name(channel, data_column, alignment_event)
        nwbs_by_week_enriched = [
            trial_metrics.get_average_signal_window_multi(
                nwb_week,
                alignment_event=alignment_event,
                offsets=offsets,
                channel=channel,
                data_column=data_column,
                output_col=avg_signal_col,
            )
            for nwb_week in nwbs_by_week_enriched
        ]

        # get rpe slope per session
        df_trials_all = pd.concat(
            [nwb.df_trials for nwb_week in nwbs_by_week_enriched for nwb in nwb_week]
        )
        rpe_slope = []
        for ses_idx in sorted(df_trials_all['ses_idx'].unique()):
            data = df_trials_all[df_trials_all['ses_idx'] == ses_idx]
            data = data.dropna(subset=[avg_signal_col, 'RPE_earned'])
            if len(data) == 0:
                # Nothing usable for this session and channel.
                continue
            data_neg = data[data['RPE_earned'] < 0]
            data_pos = data[data['RPE_earned'] >= 0]

            # assumes ses_idx looks like "<subject>_<date>..." -- TODO confirm
            ses_date = pd.to_datetime(ses_idx.split('_')[1])
            (_, _, slope_pos) = summary_plots.get_RPE_by_avg_signal_fit(data_pos, avg_signal_col)
            (_, _, slope_neg) = summary_plots.get_RPE_by_avg_signal_fit(data_neg, avg_signal_col)
            rpe_slope.append([ses_date, slope_pos, slope_neg])
        rpe_slope = pd.DataFrame(rpe_slope, columns=['date', 'slope (RPE >= 0)', 'slope (RPE < 0)'])
        rpe_slope_dict[channel] = rpe_slope

    subject_id = str(nwbs_by_week_enriched[0][0]).split(' ')[1].split('_')[0]
    # Concatenate with keys, turning dict keys into an index
    combined_rpe_slope = pd.concat(rpe_slope_dict, names=["channel"])
    combined_rpe_slope = combined_rpe_slope.reset_index(level="channel").reset_index(drop=True)

    combined_rpe_slope.to_csv(f"/results/{subject_id}_rpe_slope.csv")

    return nwbs_by_week_enriched, combined_rpe_slope
56+
57+
58+
def enrich_df_trials(df_trials):
    """
    Add derived reward, RPE-bin, Q-bin and stay/switch columns to a trials table.

    The frame is modified in place and also returned.

    Parameters
    ----------
    df_trials : pd.DataFrame
        Must contain ``ses_idx``, ``earned_reward``, ``extra_reward``,
        ``RPE_earned``, ``Q_chosen``, ``choice``, ``choice_time_in_trial`` and
        ``goCue_start_time_in_trial``.
        Assumes rows are ordered by session and then by trial, because runs of
        equal outcomes are detected in row order -- TODO confirm with callers.

    Returns
    -------
    pd.DataFrame
        The same object, with these columns added:

        - ``reward_all``: ``earned_reward + extra_reward``.
        - ``rewarded_prev``: ``reward_all`` of the previous trial in the session.
        - ``num_reward_past``: signed length of the current run of identical
          outcomes, counting the current trial; positive on rewarded trials,
          negative where ``reward_all == 0``.
        - ``RPE-binned3``: ``RPE_earned`` cut into six bins of width 1/3 over
          [-1, 1]; both end points are included.
        - ``Qch-binned3``: ordered categorical combining reward outcome
          (``R+``/``R-``) with the ``Q_chosen`` tercile.
        - ``stay`` / ``switch``: current choice equal to / different from the
          previous choice in the same session; both are False when either
          choice is missing (e.g. the first trial of a session).
        - ``response_time``: ``choice_time_in_trial - goCue_start_time_in_trial``.
    """
    ##### PART I: REWARD #######
    df_trials['reward_all'] = df_trials['earned_reward'] + df_trials['extra_reward']
    # Previous outcome within the session; NaN on each session's first trial,
    # which also forces a new run to start there.
    df_trials['rewarded_prev'] = df_trials.groupby('ses_idx')['reward_all'].shift(1)

    # A new run id starts whenever the outcome differs from the previous trial.
    run_id = (df_trials['rewarded_prev'] != df_trials['reward_all']).cumsum()
    df_trials['num_reward_past'] = df_trials.groupby(run_id).cumcount() + 1

    # Runs of unrewarded trials are reported as negative counts.
    unrewarded = df_trials['reward_all'] == 0
    df_trials.loc[unrewarded, 'num_reward_past'] = -df_trials.loc[unrewarded, 'num_reward_past']

    ##### PART II: BINNING RPE #######
    # Labels are the rounded left edge of each bin. Edges come from a
    # floating-point arange, so they are only approximately multiples of 1/3.
    RPE_binned3_label_names = [str(np.round(i, 2)) for i in np.arange(-1, 0.99, 1/3)]

    rpe_edges = np.arange(-1, 1.01, 1/3)
    rpe_edges[-1] = 1.001  # widen the last edge slightly past 1

    # include_lowest keeps RPE == -1 in the first bin instead of turning it
    # into NaN, matching how Q_chosen is binned in part III.
    df_trials['RPE-binned3'] = pd.cut(
        df_trials['RPE_earned'],  # all versus earned not a huge difference
        bins=rpe_edges,
        right=True,
        include_lowest=True,
        labels=RPE_binned3_label_names,
    )

    ##### PART III: BINNING QCHOSEN #######
    q_edges = [0.0, 1/3, 2/3, 1.01]
    q_labels = ["Qch 0", "Qch 0.33", "Qch 0.66"]

    q_bin = pd.cut(df_trials['Q_chosen'], bins=q_edges, labels=q_labels,
                   include_lowest=True, right=True)
    reward_label = df_trials['earned_reward'].map({True: "R+", False: "R-"})

    # Combined label per trial; None wherever Q_chosen fell outside the bins.
    reward_Qcat_series = pd.Series(
        np.where(q_bin.isna(), None,
                 reward_label.astype(str) + " (" + q_bin.astype(str) + ")"),
        index=df_trials.index,
    )

    # Category order: unrewarded terciles first, then rewarded terciles.
    Qch_binned3_label_names = [
        "R- (Qch 0)", "R- (Qch 0.33)", "R- (Qch 0.66)",
        "R+ (Qch 0)", "R+ (Qch 0.33)", "R+ (Qch 0.66)",
    ]
    df_trials['Qch-binned3'] = pd.Categorical(
        reward_Qcat_series, categories=Qch_binned3_label_names, ordered=True
    )

    ##### PART IV: GETTING STAY/LEAVE #######
    _choice_shifted = df_trials.groupby('ses_idx')['choice'].shift(1)
    # A switch needs two known choices. Without this mask a missing previous
    # choice compares as "not equal" and flags every session's first trial.
    _both_known = df_trials['choice'].notna() & _choice_shifted.notna()
    df_trials['stay'] = df_trials['choice'] == _choice_shifted
    df_trials['switch'] = (df_trials['choice'] != _choice_shifted) & _both_known
    df_trials['response_time'] = (
        df_trials['choice_time_in_trial'] - df_trials['goCue_start_time_in_trial']
    )

    return df_trials

0 commit comments

Comments
 (0)