Skip to content

Commit f046db8

Browse files
moved a function to analysis_utils
1 parent b832718 commit f046db8

1 file changed

Lines changed: 21 additions & 3 deletions

File tree

src/rachel_analysis_utils/analysis_utils.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
import numpy as np
22
import pandas as pd
3-
from analysis_wrapper.plots import summary_plots
43
from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics
54

65

6+
7+
def get_RPE_by_avg_signal_fit(data, avg_signal_col):
8+
9+
10+
x = data['RPE_earned'].values
11+
y = data[avg_signal_col].values
12+
try:
13+
lr = stats.linregress(x, y)
14+
x_fit = np.linspace(x.min(), x.max(), 100)
15+
y_fit = lr.intercept + lr.slope * x_fit
16+
slope = lr.slope
17+
except ValueError as e:
18+
print(f"Error in linear regression: {e}")
19+
x_fit = np.nan * np.arange(100)
20+
y_fit = np.nan * np.arange(100)
21+
slope = np.nan
22+
return (x_fit, y_fit, slope)
23+
24+
725
def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column = 'data_z_norm',
826
alignment_event = 'choice_time_in_session',offsets = [0.33,1]):
927
rpe_slope_dict = {}
@@ -39,8 +57,8 @@ def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column = 'data_z_norm',
3957
data_pos = data[data['RPE_earned'] >= 0]
4058

4159
ses_date = pd.to_datetime(ses_idx.split('_')[1])
42-
(_,_, slope_pos) = summary_plots.get_RPE_by_avg_signal_fit(data_pos, avg_signal_col)
43-
(_,_, slope_neg) = summary_plots.get_RPE_by_avg_signal_fit(data_neg, avg_signal_col)
60+
(_,_, slope_pos) = get_RPE_by_avg_signal_fit(data_pos, avg_signal_col)
61+
(_,_, slope_neg) = get_RPE_by_avg_signal_fit(data_neg, avg_signal_col)
4462
rpe_slope.append([ses_date, slope_pos, slope_neg])
4563
rpe_slope = pd.DataFrame(rpe_slope, columns=['date', 'slope (RPE >= 0)', 'slope (RPE < 0)'])
4664
rpe_slope_dict[channel] = rpe_slope

0 commit comments

Comments
 (0)