Skip to content

Commit cb20fce

Browse files
authored
Merge pull request #85 from AllenNeuralDynamics/kh-dev_snr-metrics
add SNR and kurtosis metric functions for quality assessment and data curation
2 parents 08071ec + e170f02 commit cb20fce

1 file changed

Lines changed: 134 additions & 0 deletions

File tree

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
"""
2+
Utilities for signal quality metrics on 1D fluorescence traces.
3+
4+
This module provides:
5+
- :func:`estimate_snr` — an SNR estimator using a derivative-based noise
6+
estimate and peak-based signal estimate.
7+
- :func:`estimate_kurtosis` — excess kurtosis of the trace distribution.
8+
9+
Notes
10+
-----
11+
- The SNR function design was inspired by the AIND-OPhys SLAP2 team.
12+
- Feed a dF/F preprocessed trace to :func:`estimate_snr`, as the peak
13+
height is interpreted from zero.
14+
- Default sampling frequency (``fps``) is 20 Hz; adjust it if your data
15+
differ.
16+
- NaNs are filled with the median of the trace prior to computation.
17+
18+
Example
19+
-------
20+
>>> import numpy as np
21+
>>> t = np.linspace(0, 10, 200, dtype=float) # ~20 Hz sampling (200 samples over 10 s, endpoint included)
22+
>>> y = 0.1 * np.sin(2 * np.pi * 1.0 * t) # small signal
23+
>>> snr, noise, peaks = estimate_snr(y) # doctest: +ELLIPSIS
24+
>>> isinstance(snr, float) and isinstance(noise, float)
25+
True
26+
>>> isinstance(peaks, np.ndarray)
27+
True
28+
>>> isinstance(estimate_kurtosis(y), float)
29+
True
30+
"""
31+
32+
from __future__ import annotations
33+
34+
import warnings
35+
from typing import Tuple
36+
37+
import numpy as np
38+
from numpy.typing import NDArray
39+
from scipy.signal import find_peaks
40+
from scipy.stats import kurtosis
41+
42+
# Names exported by a star-import of this module: the two metric functions
# defined below.
__all__ = ["estimate_snr", "estimate_kurtosis"]
43+
44+
45+
def estimate_snr(
    trace: NDArray[np.floating], fps: float = 20.0
) -> Tuple[float, float, NDArray[np.intp]]:
    """
    Estimate the signal-to-noise ratio (SNR) of a 1D trace.

    Parameters
    ----------
    trace : numpy.ndarray
        1D input trace (e.g., dF/F). NaNs are replaced with the median of
        the non-NaN samples before calculation.
    fps : float, optional
        Sampling frequency (frames per second), by default ``20.0``. It sets
        the minimum spacing between detected peaks (0.1 s worth of samples,
        but never less than one sample).

    Returns
    -------
    snr : float
        Estimated signal-to-noise ratio (dimensionless). ``inf`` when the
        noise estimate is zero, ``NaN`` when fewer than three peaks are found.
    noise : float
        Estimated noise level computed from the first difference of the trace
        (standard deviation of ``diff(trace)`` divided by ``sqrt(2)``).
    peaks : numpy.ndarray
        Indices of detected peaks in the trace. When fewer than three peaks
        are found, a 0-d array holding ``NaN`` is returned instead.

    Warns
    -----
    RuntimeWarning
        If fewer than three peaks are found.

    Notes
    -----
    - Noise is estimated from the derivative assuming white noise.
    - Signal is estimated as the 95th percentile of the trace values at the
      detected peaks.
    - Peak detection uses ``scipy.signal.find_peaks`` with a height threshold
      of ``3 * noise``, a prominence of ``0.05`` and a width of ``5`` samples.
    """
    # Replace NaNs with the median of the trace. Note that ``nan_to_num``
    # also maps +/-inf to very large finite values.
    trace = np.nan_to_num(trace, nan=np.nanmedian(trace))

    # Noise estimation based on the derivative, assuming random noise:
    # differencing two independent samples doubles the variance, hence the
    # division by sqrt(2).
    dfdt = np.diff(trace)
    noise = float(np.std(dfdt) / np.sqrt(2))

    # ``find_peaks`` rejects ``distance < 1`` with a ValueError, which the
    # unclamped ``fps * 0.1`` hits for any fps below 10 Hz.
    min_distance = max(1.0, fps * 0.1)

    # Peak detection
    peaks, _ = find_peaks(
        trace,
        height=3 * noise,  # Minimum peak height (adjust for your scale)
        distance=min_distance,  # Minimum number of samples between peaks
        prominence=0.05,  # How much a peak stands out relative to neighbors
        width=5,  # Minimum width of a peak, in samples
    )

    if len(peaks) < 3:
        warnings.warn(
            "Not enough peaks found to estimate SNR. Returning NaN values.",
            RuntimeWarning,
            stacklevel=2,
        )
        # Kept as a 0-d NaN array so existing callers see the same sentinel.
        return float("nan"), noise, np.array(np.nan)

    # Signal estimate: 95th percentile of detected peak amplitudes
    # (``np.percentile`` does not need pre-sorted input).
    signal = float(np.percentile(trace[peaks], 95))

    # Calculate SNR, guarding against a perfectly noiseless trace
    snr = float(signal / noise) if noise > 0 else float("inf")

    return snr, noise, peaks
111+
112+
113+
def estimate_kurtosis(trace: NDArray[np.floating]) -> float:
    """
    Return the excess (Fisher) kurtosis of a 1D trace.

    Parameters
    ----------
    trace : numpy.ndarray
        1D input trace. NaN samples are substituted with the median of the
        remaining samples before the statistic is computed.

    Returns
    -------
    float
        Bias-corrected excess kurtosis of the sample distribution:

        - Normal distribution → 0.0
        - Heavy-tailed (leptokurtic) → positive
        - Light-tailed (platykurtic) → negative
    """
    # Fill value for missing samples: the median ignoring NaNs
    fill_value = np.nanmedian(trace)
    cleaned = np.nan_to_num(trace, nan=fill_value)

    # fisher=True subtracts 3 so that a normal distribution scores 0;
    # bias=False applies the sample-size correction.
    excess = kurtosis(cleaned, fisher=True, bias=False)
    return float(excess)

0 commit comments

Comments
 (0)