Skip to content

Commit a1cb985

Browse files
added first pass subsample
1 parent 31c6eeb commit a1cb985

1 file changed

Lines changed: 47 additions & 0 deletions

File tree

src/rachel_analysis_utils/nwb_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,53 @@
1212

1313
import copy
1414

15+
def pick_side(side_pos, seg_bounds, per_seg_min):
    """Randomly draw ``per_seg_min`` positions from every segment.

    Sampling is done without replacement via ``np.random.choice``, so results
    depend on NumPy's global random state.

    Parameters
    ----------
    side_pos : array-like of int
        Candidate positions to sample from.
    seg_bounds : iterable of (lo, hi)
        Half-open segment ranges; a position belongs to a segment when
        ``lo <= position < hi``.
    per_seg_min : int
        Number of positions to draw from each segment.

    Returns
    -------
    all_selected : np.ndarray
        All drawn positions across segments, sorted ascending. Empty when
        nothing was drawn.
    side_by_seg : list of np.ndarray
        One sorted array of drawn positions per entry of ``seg_bounds``.

    Raises
    ------
    ValueError
        If ``per_seg_min`` is negative or exceeds the number of candidates
        available in some segment.
    """
    if per_seg_min < 0:
        raise ValueError(f"per_seg_min must be non-negative, got {per_seg_min}")

    side_pos = np.asarray(side_pos)
    side_by_seg = []
    for lo, hi in seg_bounds:
        side_in = side_pos[(side_pos >= lo) & (side_pos < hi)]
        if per_seg_min > side_in.size:
            raise ValueError(
                f"Cannot draw {per_seg_min} positions from segment [{lo}, {hi}) "
                f"containing only {side_in.size}"
            )
        if per_seg_min == 0:
            # Some NumPy versions refuse to sample from an empty array even
            # when zero samples are requested, so skip the call entirely.
            side_in_subset = side_in[:0].copy()
        else:
            side_in_subset = np.random.choice(side_in, size=per_seg_min, replace=False)
        side_by_seg.append(np.sort(side_in_subset))

    # np.concatenate raises on an empty list, which happens when no segment
    # contributed any positions.
    non_empty = [arr for arr in side_by_seg if arr.size]
    if non_empty:
        all_selected = np.sort(np.concatenate(non_empty))
    else:
        all_selected = np.array([], dtype=side_pos.dtype)
    return all_selected, side_by_seg
23+
24+
def subsample_lr_thirds(nwb, per_seg_min = 50):
    """Return a copy of ``nwb`` whose trials are balanced left/right per third.

    The trial table is split by row position into three consecutive thirds.
    The same number of left (``choice == 0``) and right (``choice == 1``)
    trials is then drawn at random from each third. That number is the
    smallest per-third left or right count, capped at ``per_seg_min``.

    Parameters
    ----------
    nwb : object
        Must expose a ``df_trials`` DataFrame with a ``'choice'`` column.
        It is deep-copied; the input object is not modified.
    per_seg_min : int, default 50
        Upper bound on the number of trials drawn per side per third.

    Returns
    -------
    object
        Deep copy of ``nwb`` whose ``df_trials`` holds only the selected
        trials, in their original order, with a fresh 0..k-1 index. The table
        is empty when some third has no left or no right trials.

    Raises
    ------
    ValueError
        If ``per_seg_min`` is negative.
    """
    if per_seg_min < 0:
        raise ValueError(f"per_seg_min must be non-negative, got {per_seg_min}")

    nwb_subset = copy.deepcopy(nwb)
    df_trials = nwb.df_trials
    n = len(df_trials)

    # thirds boundaries (integer division)
    t1 = n // 3
    t2 = 2 * n // 3
    seg_bounds = [(0, t1), (t1, t2), (t2, n)]

    # Row positions (not index labels) of each choice side.
    choice = df_trials['choice'].to_numpy()
    left_pos = np.flatnonzero(choice == 0)
    right_pos = np.flatnonzero(choice == 1)

    # The draw size is limited by the scarcest side in any third, tracked in
    # its own local so the caller-supplied cap is not overwritten.
    n_per_seg = int(per_seg_min)
    for lo, hi in seg_bounds:
        n_left = int(np.count_nonzero((left_pos >= lo) & (left_pos < hi)))
        n_right = int(np.count_nonzero((right_pos >= lo) & (right_pos < hi)))
        n_per_seg = min(n_per_seg, n_left, n_right)

    if n_per_seg == 0:
        # Nothing can be drawn in a balanced way; skip sampling altogether.
        sel_idx = np.array([], dtype=int)
    else:
        left_idx, _ = pick_side(left_pos, seg_bounds, n_per_seg)
        right_idx, _ = pick_side(right_pos, seg_bounds, n_per_seg)

        # ensure equal totals (they should be by construction)
        keep = min(left_idx.size, right_idx.size)
        sel_idx = np.sort(np.concatenate([left_idx[:keep], right_idx[:keep]]))

    # sel_idx holds row positions, so select with .iloc; .loc would treat them
    # as index labels and break on any non-default index.
    df_sub = df_trials.iloc[sel_idx].copy().reset_index(drop=True)

    nwb_subset.df_trials = df_sub
    return nwb_subset
61+
1562
def split_nwb_by_choice(nwb):
1663
nwb_split = copy.deepcopy(nwb)
1764
nwb_split.df_trials_left = nwb.df_trials.query('choice == 0.0')

0 commit comments

Comments
 (0)