|
1 | 1 | import numpy as np |
2 | 2 | import pandas as pd |
3 | | -from analysis_wrapper.plots import summary_plots |
4 | 3 | from aind_dynamic_foraging_basic_analysis.metrics import trial_metrics |
5 | 4 |
|
6 | 5 |
|
| 6 | + |
| 7 | +def get_RPE_by_avg_signal_fit(data, avg_signal_col): |
| 8 | + |
| 9 | + |
| 10 | + x = data['RPE_earned'].values |
| 11 | + y = data[avg_signal_col].values |
| 12 | + try: |
| 13 | + lr = stats.linregress(x, y) |
| 14 | + x_fit = np.linspace(x.min(), x.max(), 100) |
| 15 | + y_fit = lr.intercept + lr.slope * x_fit |
| 16 | + slope = lr.slope |
| 17 | + except ValueError as e: |
| 18 | + print(f"Error in linear regression: {e}") |
| 19 | + x_fit = np.nan * np.arange(100) |
| 20 | + y_fit = np.nan * np.arange(100) |
| 21 | + slope = np.nan |
| 22 | + return (x_fit, y_fit, slope) |
| 23 | + |
| 24 | + |
7 | 25 | def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column = 'data_z_norm', |
8 | 26 | alignment_event = 'choice_time_in_session',offsets = [0.33,1]): |
9 | 27 | rpe_slope_dict = {} |
@@ -39,8 +57,8 @@ def add_AUC_and_rpe_slope(nwbs_by_week, parameters, data_column = 'data_z_norm', |
39 | 57 | data_pos = data[data['RPE_earned'] >= 0] |
40 | 58 |
|
41 | 59 | ses_date = pd.to_datetime(ses_idx.split('_')[1]) |
42 | | - (_,_, slope_pos) = summary_plots.get_RPE_by_avg_signal_fit(data_pos, avg_signal_col) |
43 | | - (_,_, slope_neg) = summary_plots.get_RPE_by_avg_signal_fit(data_neg, avg_signal_col) |
| 60 | + (_,_, slope_pos) = get_RPE_by_avg_signal_fit(data_pos, avg_signal_col) |
| 61 | + (_,_, slope_neg) = get_RPE_by_avg_signal_fit(data_neg, avg_signal_col) |
44 | 62 | rpe_slope.append([ses_date, slope_pos, slope_neg]) |
45 | 63 | rpe_slope = pd.DataFrame(rpe_slope, columns=['date', 'slope (RPE >= 0)', 'slope (RPE < 0)']) |
46 | 64 | rpe_slope_dict[channel] = rpe_slope |
|
0 commit comments