Skip to content

Commit 041c072

Browse files
Merge pull request #103 from AllenNeuralDynamics/auroc
adding auroc to session_metrics
2 parents c07fb51 + 88b369d commit 041c072

1 file changed

Lines changed: 156 additions & 0 deletions

File tree

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""
2+
Tools for computing per session metrics
3+
compute_auroc: compute auroc for one NWB given alignments
4+
compute_auroc_multi: compute auroc for multiple NWB given alignments
5+
6+
"""
7+
8+
from sklearn.metrics import roc_auc_score
9+
from aind_dynamic_foraging_basic_analysis.plot import plot_fip as pf
10+
import warnings
11+
import pandas as pd
12+
import numpy as np
13+
14+
15+
def compute_auroc(nwb, alignment_times, labels, channel, tw, bin_size=0.25, data_col="data_z"):
16+
"""
17+
Compute the time-resolved area under the ROC curve (auROC) for a single NWB session.
18+
19+
Parameters
20+
- nwb: object
21+
NWB session object expected to contain a DataFrame `df_fip` with
22+
FIP data and a `session_id`.
23+
- alignment_times: array-like, shape (n_trials,)
24+
Times to align trials to (seconds), given in session time
25+
- labels: array-like, shape (n_trials,)
26+
Binary labels (0/1) for each alignment time. Must have same
27+
length as alignment_times.
28+
- channel: str
29+
Channel name to select from `nwb.df_fip.event`.
30+
- tw: tuple (start, end)
31+
Time window (seconds) around the alignment to compute auROC over
32+
(centered bins will be between tw[0] and tw[1]).
33+
- bin_size: float, optional
34+
Width (seconds) of each time bin used to aggregate values
35+
before computing auROC. Default 0.25s.
36+
- data_col: str, optional
37+
Column name in the FIP data to use for values (default is z-scored data, 'data_z').
38+
39+
Returns
40+
- pandas.DataFrame
41+
DataFrame with columns:
42+
- 'bin_center': center time of each bin (seconds)
43+
- 'auc': auROC value for that bin (NaN when computation failed)
44+
If the requested channel is not present in the NWB,
45+
returns an empty DataFrame with those columns.
46+
47+
Notes
48+
- alignment_times and labels are sorted together before computing PSTHs.
49+
- Trials with NaNs in the aggregated bin are dropped;
50+
event_numbers that contain any NaNs across bins are removed.
51+
"""
52+
if len(labels) != len(alignment_times):
53+
raise Exception("Alignment times must have same number of labels ")
54+
55+
if np.unique(labels).size > 2:
56+
raise Exception("Labels must be binary for auROC computation")
57+
58+
if channel not in nwb.df_fip.event.unique():
59+
warnings.warn("No channel found in this NWB, returning empty DataFrame")
60+
return pd.DataFrame(columns=["bin_center", "auc"])
61+
62+
# sort labels and alignment times
63+
sorted_indices = np.argsort(alignment_times)
64+
alignment_times = alignment_times[sorted_indices]
65+
labels = labels[sorted_indices]
66+
67+
tw_for_center_bin = [tw[0] - bin_size / 2, tw[1] + bin_size / 2]
68+
69+
# get alignments
70+
aligns = pf.fip_psth_inner_compute(
71+
nwb, alignment_times, channel, average=False, tw=tw_for_center_bin, data_column=data_col
72+
)
73+
n_centers = int(round((tw[1] - tw[0]) / bin_size)) + 1
74+
75+
# bin the time values into discrete bins and compute bin centers
76+
left0 = tw_for_center_bin[0]
77+
edges = left0 + np.arange(n_centers + 1) * bin_size
78+
aligns["time_bin"] = pd.cut(aligns["time"], bins=edges, right=False, include_lowest=True)
79+
aligns["bin_center"] = aligns["time_bin"].apply(
80+
lambda iv: (iv.left + float(bin_size) / 2.0) if pd.notnull(iv) else np.nan
81+
)
82+
83+
aligns = aligns.dropna(subset=["bin_center", data_col]).copy()
84+
85+
# average by bin_centers
86+
agg_align = (
87+
aligns.groupby(["bin_center", "event_number"], observed=True)[data_col]
88+
.mean()
89+
.unstack(["event_number"])
90+
)
91+
# drop any event_number with nan values for any bin_centers.
92+
agg_align = agg_align.dropna(how="any", axis=1)
93+
94+
# calculate auROC
95+
aucs = []
96+
labels_valid = labels[agg_align.columns.values]
97+
for bin_center, row in agg_align.iterrows():
98+
try:
99+
auc_val = roc_auc_score(labels_valid, row.values)
100+
except Exception:
101+
auc_val = np.nan
102+
aucs.append(auc_val)
103+
104+
curr_auc_df = pd.DataFrame(
105+
{"bin_center": agg_align.index.values, "auc": np.asarray(aucs, dtype=float)}
106+
)
107+
108+
return curr_auc_df
109+
110+
111+
def compute_auroc_multi(nwb_list, alignment_times_list, label_list, channel, tw, bin_size=0.25):
112+
"""
113+
Compute auROC across multiple NWB sessions and return a session x time-bin table.
114+
115+
Parameters
116+
- nwb_list: sequence of NWB objects
117+
Each element should provide FIP data and a `session_id`.
118+
- alignment_times_list: sequence of array-like
119+
Per-session alignment times; must be same length as nwb_list.
120+
- label_list: sequence of array-like
121+
Per-session labels corresponding to alignment times; must be same length as nwb_list.
122+
- channel: str
123+
Channel name to use in each NWB.
124+
- tw: tuple (start, end)
125+
Time window (seconds) around alignments to compute auROC over.
126+
- bin_size: float, optional
127+
Time bin width for aggregation (default 0.25s).
128+
129+
Returns
130+
- pandas.DataFrame
131+
Concatenated DataFrame where each row is a session (index = session_id)
132+
and each column is a bin_center; cell values are the auROC for that session
133+
and bin. If no sessions produced results, an empty DataFrame is returned.
134+
"""
135+
136+
if len(nwb_list) != len(alignment_times_list) or len(nwb_list) != len(label_list):
137+
raise ValueError("nwb_list, alignment_times_list, label_list must have the same length")
138+
139+
# across sessions, should alway use z-scored data to compare
140+
data_col = "data_z"
141+
142+
auc_df_list = []
143+
for nwb, align_times, labels in zip(nwb_list, alignment_times_list, label_list):
144+
auc_df = compute_auroc(nwb, align_times, labels, channel, tw, bin_size, data_col)
145+
if auc_df.empty:
146+
continue
147+
auc_df["session_id"] = nwb.session_id
148+
# pivot to single-row DataFrame: index=session_id, columns=bin_center, values=auc
149+
row = auc_df.pivot(index="session_id", columns="bin_center", values="auc")
150+
auc_df_list.append(row)
151+
152+
if len(auc_df_list) == 0:
153+
return pd.DataFrame()
154+
155+
# Concatenate all DataFrames in the list
156+
return pd.concat(auc_df_list, axis=0)

0 commit comments

Comments
 (0)