|
12 | 12 |
|
13 | 13 | import copy |
14 | 14 |
|
def pick_side(side_pos, seg_bounds, per_seg_min):
    """Randomly draw a fixed number of positions from each segment.

    Parameters
    ----------
    side_pos : numpy.ndarray
        1-D array of candidate positions (it is filtered with boolean masks).
    seg_bounds : iterable of (lo, hi)
        Half-open ``[lo, hi)`` ranges; one draw is made per range.
    per_seg_min : int
        Number of positions drawn, without replacement, from each range.

    Returns
    -------
    all_selected : numpy.ndarray
        Every drawn position, sorted ascending. Empty (with the dtype of
        ``side_pos``) when nothing was drawn.
    side_by_seg : list of numpy.ndarray
        The sorted draw for each range, in the order of ``seg_bounds``.

    Raises
    ------
    ValueError
        Raised by ``np.random.choice`` when a range holds fewer than
        ``per_seg_min`` candidates.
    """
    side_by_seg = []
    for lo, hi in seg_bounds:
        side_in = side_pos[(side_pos >= lo) & (side_pos < hi)]
        side_in_subset = np.random.choice(
            side_in, size=per_seg_min, replace=False
        )
        side_by_seg.append(np.sort(side_in_subset))

    non_empty = [seg for seg in side_by_seg if seg.size]
    if non_empty:
        all_selected = np.sort(np.concatenate(non_empty))
    else:
        # np.concatenate([]) raises ValueError; an empty draw (e.g.
        # per_seg_min == 0 or no segments) is a valid, empty result.
        all_selected = np.empty(0, dtype=np.asarray(side_pos).dtype)
    return all_selected, side_by_seg
| 23 | + |
def subsample_lr_thirds(nwb, per_seg_min=50):
    """Return a copy of ``nwb`` with left/right trials balanced per third.

    The trial table is split by row position into three consecutive
    segments. The same number of ``choice == 0`` (left) and ``choice == 1``
    (right) trials is drawn at random from every segment. That number is
    the smallest left-or-right count found in any segment, capped at
    ``per_seg_min``.

    Parameters
    ----------
    nwb : object
        Must expose a ``df_trials`` DataFrame with a ``'choice'`` column.
        It is deep-copied and not modified.
    per_seg_min : int, optional
        Upper bound on the trials drawn per side per segment.

    Returns
    -------
    object
        Deep copy of ``nwb`` whose ``df_trials`` holds only the selected
        rows, in their original order, with a fresh 0..k-1 index. The table
        is empty when some segment has no left or no right trials.
    """
    nwb_subset = copy.deepcopy(nwb)
    df_trials = nwb.df_trials
    n = len(df_trials)

    # thirds boundaries (integer division)
    t1 = n // 3
    t2 = 2 * n // 3
    seg_bounds = [(0, t1), (t1, t2), (t2, n)]

    # Work with row positions, not index labels, so the result does not
    # depend on how df_trials happens to be indexed.
    positions = np.arange(n)
    choice = df_trials['choice'].to_numpy()
    left_pos = positions[choice == 0]
    right_pos = positions[choice == 1]

    # Largest per-segment draw that every segment can supply on both sides.
    n_per_seg = int(per_seg_min)
    for lo, hi in seg_bounds:
        n_left = int(np.count_nonzero((left_pos >= lo) & (left_pos < hi)))
        n_right = int(np.count_nonzero((right_pos >= lo) & (right_pos < hi)))
        n_per_seg = min(n_per_seg, n_left, n_right)

    if n_per_seg == 0:
        # Nothing can be balanced; return an empty trial table instead of
        # failing inside the sampling helper.
        nwb_subset.df_trials = df_trials.iloc[0:0].copy().reset_index(drop=True)
        return nwb_subset

    left_idx, _ = pick_side(left_pos, seg_bounds, n_per_seg)
    right_idx, _ = pick_side(right_pos, seg_bounds, n_per_seg)

    # ensure equal totals (they should be by construction)
    keep = min(left_idx.size, right_idx.size)
    left_idx = left_idx[:keep]
    right_idx = right_idx[:keep]

    sel_idx = np.sort(np.concatenate([left_idx, right_idx]))
    # sel_idx holds row positions, so select with .iloc; .loc would treat
    # them as index labels and pick wrong rows on a non-default index.
    df_sub = df_trials.iloc[sel_idx].copy().reset_index(drop=True)

    nwb_subset.df_trials = df_sub
    return nwb_subset
| 61 | + |
15 | 62 | def split_nwb_by_choice(nwb): |
16 | 63 | nwb_split = copy.deepcopy(nwb) |
17 | 64 | nwb_split.df_trials_left = nwb.df_trials.query('choice == 0.0') |
|
0 commit comments