Skip to content

Commit 5a56f09

Browse files
linting
1 parent 1413d7d commit 5a56f09

1 file changed

Lines changed: 54 additions & 50 deletions

File tree

Lines changed: 54 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,38 @@
11
"""
2-
Tools for computing per session metrics
3-
compute_auroc: compute auroc for one NWB given alignments
4-
compute_auroc_multi: compute auroc for multiple NWB given alignments
2+
Tools for computing per session metrics
3+
compute_auroc: compute auroc for one NWB given alignments
4+
compute_auroc_multi: compute auroc for multiple NWB given alignments
55
66
"""
77

8-
98
from sklearn.metrics import roc_auc_score
109
from aind_dynamic_foraging_basic_analysis.plot import plot_fip as pf
1110
import warnings
1211
import pandas as pd
1312
import numpy as np
1413

1514

16-
def compute_auroc(nwb, alignment_times, labels, channel, tw, bin_size = 0.25, data_col='data_z'):
15+
def compute_auroc(nwb, alignment_times, labels, channel, tw, bin_size=0.25, data_col="data_z"):
1716
"""
1817
Compute the time-resolved area under the ROC curve (auROC) for a single NWB session.
1918
2019
Parameters
2120
- nwb: object
22-
NWB session object expected to contain a DataFrame `df_fip` with FIP data and a `session_id`.
21+
NWB session object expected to contain a DataFrame `df_fip` with
22+
FIP data and a `session_id`.
2323
- alignment_times: array-like, shape (n_trials,)
2424
Times to align trials to (seconds), given in session time
2525
- labels: array-like, shape (n_trials,)
26-
Binary labels (0/1) for each alignment time. Must have same length as alignment_times.
26+
Binary labels (0/1) for each alignment time. Must have same
27+
length as alignment_times.
2728
- channel: str
2829
Channel name to select from `nwb.df_fip.event`.
2930
- tw: tuple (start, end)
30-
Time window (seconds) around the alignment to compute auROC over (centered bins will be between tw[0] and tw[1]).
31+
Time window (seconds) around the alignment to compute auROC over
32+
(centered bins will be between tw[0] and tw[1]).
3133
- bin_size: float, optional
32-
Width (seconds) of each time bin used to aggregate values before computing auROC. Default 0.25s.
34+
Width (seconds) of each time bin used to aggregate values
35+
before computing auROC. Default 0.25s.
3336
- data_col: str, optional
3437
Column name in the FIP data to use for values (default is z-scored data, 'data_z').
3538
@@ -38,72 +41,72 @@ def compute_auroc(nwb, alignment_times, labels, channel, tw, bin_size = 0.25, da
3841
DataFrame with columns:
3942
- 'bin_center': center time of each bin (seconds)
4043
- 'auc': auROC value for that bin (NaN when computation failed)
41-
If the requested channel is not present in the NWB, returns an empty DataFrame with those columns.
44+
If the requested channel is not present in the NWB,
45+
returns an empty DataFrame with those columns.
4246
4347
Notes
4448
- alignment_times and labels are sorted together before computing PSTHs.
45-
- Trials with NaNs in the aggregated bin are dropped; event_numbers that contain any NaNs across bins are removed.
49+
- Trials with NaNs in the aggregated bin are dropped;
50+
event_numbers that contain any NaNs across bins are removed.
4651
"""
4752
if len(labels) != len(alignment_times):
48-
raise Exception('Alignment times must have same number of labels ')
53+
raise Exception("Alignment times must have same number of labels ")
4954

5055
if np.unique(labels).size > 2:
51-
raise Exception('Labels must be binary for auROC computation')
52-
56+
raise Exception("Labels must be binary for auROC computation")
57+
5358
if channel not in nwb.df_fip.event.unique():
5459
warnings.warn("No channel found in this NWB, returning empty DataFrame")
55-
return pd.DataFrame(columns=['bin_center', 'auc'])
60+
return pd.DataFrame(columns=["bin_center", "auc"])
5661

5762
# sort labels and alignment times
5863
sorted_indices = np.argsort(alignment_times)
5964
alignment_times = alignment_times[sorted_indices]
6065
labels = labels[sorted_indices]
6166

62-
tw_for_center_bin = [tw[0] - bin_size/2, tw[1] + bin_size/2]
67+
tw_for_center_bin = [tw[0] - bin_size / 2, tw[1] + bin_size / 2]
6368

64-
# get alignments
69+
# get alignments
6570
aligns = pf.fip_psth_inner_compute(
66-
nwb,
67-
alignment_times,
68-
channel,
69-
average = False,
70-
tw=tw_for_center_bin,
71-
data_column=data_col
72-
)
71+
nwb, alignment_times, channel, average=False, tw=tw_for_center_bin, data_column=data_col
72+
)
7373
n_centers = int(round((tw[1] - tw[0]) / bin_size)) + 1
7474

7575
# bin the time values into discrete bins and compute bin centers
7676
left0 = tw_for_center_bin[0]
7777
edges = left0 + np.arange(n_centers + 1) * bin_size
78-
aligns['time_bin'] = pd.cut(aligns['time'], bins=edges, right=False, include_lowest=True)
79-
aligns['bin_center'] = aligns['time_bin'].apply(lambda iv: (iv.left + float(bin_size) / 2.0) if pd.notnull(iv) else np.nan)
78+
aligns["time_bin"] = pd.cut(aligns["time"], bins=edges, right=False, include_lowest=True)
79+
aligns["bin_center"] = aligns["time_bin"].apply(
80+
lambda iv: (iv.left + float(bin_size) / 2.0) if pd.notnull(iv) else np.nan
81+
)
8082

81-
aligns = aligns.dropna(subset=['bin_center',data_col]).copy()
83+
aligns = aligns.dropna(subset=["bin_center", data_col]).copy()
8284

8385
# average by bin_centers
84-
agg_align = aligns.groupby(['bin_center','event_number'])[data_col].mean().unstack(['event_number'])
86+
agg_align = (
87+
aligns.groupby(["bin_center", "event_number"])[data_col].mean().unstack(["event_number"])
88+
)
8589
# drop any event_number with nan values for any bin_centers.
86-
agg_align = agg_align.dropna(how='any', axis=1)
90+
agg_align = agg_align.dropna(how="any", axis=1)
8791

8892
# calculate auROC
8993
aucs = []
9094
labels_valid = labels[agg_align.columns.values]
9195
for bin_center, row in agg_align.iterrows():
92-
try:
93-
auc_val = roc_auc_score(labels_valid, row.values)
94-
except Exception:
95-
auc_val = np.nan
96-
aucs.append(auc_val)
97-
98-
curr_auc_df = pd.DataFrame({
99-
'bin_center': agg_align.index.values,
100-
'auc': np.asarray(aucs, dtype=float)
101-
})
96+
try:
97+
auc_val = roc_auc_score(labels_valid, row.values)
98+
except Exception:
99+
auc_val = np.nan
100+
aucs.append(auc_val)
102101

102+
curr_auc_df = pd.DataFrame(
103+
{"bin_center": agg_align.index.values, "auc": np.asarray(aucs, dtype=float)}
104+
)
103105

104106
return curr_auc_df
105-
106-
def compute_auroc_multi(nwb_list, alignment_times_list, label_list, channel, tw, bin_size = 0.25):
107+
108+
109+
def compute_auroc_multi(nwb_list, alignment_times_list, label_list, channel, tw, bin_size=0.25):
107110
"""
108111
Compute auROC across multiple NWB sessions and return a session x time-bin table.
109112
@@ -123,28 +126,29 @@ def compute_auroc_multi(nwb_list, alignment_times_list, label_list, channel, tw,
123126
124127
Returns
125128
- pandas.DataFrame
126-
Concatenated DataFrame where each row is a session (index = session_id) and each column is a bin_center;
127-
cell values are the auROC for that session and bin. If no sessions produced results, an empty DataFrame is returned.
129+
Concatenated DataFrame where each row is a session (index = session_id)
130+
and each column is a bin_center; cell values are the auROC for that session
131+
and bin. If no sessions produced results, an empty DataFrame is returned.
128132
"""
129133

130134
if len(nwb_list) != len(alignment_times_list) or len(nwb_list) != len(label_list):
131135
raise ValueError("nwb_list, alignment_times_list, label_list must have the same length")
132-
136+
133137
# across sessions, should alway use z-scored data to compare
134-
data_col='data_z'
135-
138+
data_col = "data_z"
139+
136140
auc_df_list = []
137141
for nwb, align_times, labels in zip(nwb_list, alignment_times_list, label_list):
138142
auc_df = compute_auroc(nwb, align_times, labels, channel, tw, bin_size, data_col)
139143
if auc_df.empty:
140144
continue
141-
auc_df['session_id'] = nwb.session_id
145+
auc_df["session_id"] = nwb.session_id
142146
# pivot to single-row DataFrame: index=session_id, columns=bin_center, values=auc
143-
row = auc_df.pivot(index='session_id', columns='bin_center', values='auc')
147+
row = auc_df.pivot(index="session_id", columns="bin_center", values="auc")
144148
auc_df_list.append(row)
145-
149+
146150
if len(auc_df_list) == 0:
147151
return pd.DataFrame()
148152

149153
# Concatenate all DataFrames in the list
150-
return pd.concat(auc_df_list, axis = 0)
154+
return pd.concat(auc_df_list, axis=0)

0 commit comments

Comments
 (0)