Skip to content

Commit bd483af

Browse files
committed
Update snr_kurtosis.py
lint
1 parent 7da0734 commit bd483af

1 file changed

Lines changed: 84 additions & 37 deletions

File tree

Lines changed: 84 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,130 @@
1-
import numpy as np
1+
"""
2+
Utilities for signal quality metrics on 1D fluorescence traces.
3+
4+
This module provides:
5+
- :func:`estimate_snr` — an SNR estimator using a derivative-based noise
6+
estimate and peak-based signal estimate.
7+
- :func:`estimate_kurtosis` — excess kurtosis of the trace distribution.
8+
9+
Notes
10+
-----
11+
- The SNR function design was inspired by the AIND-OPhys SLAP2 team.
12+
- Feed a dF/F preprocessed trace to :func:`estimate_snr`, as the peak
13+
height is interpreted from zero.
14+
- Default sampling frequency (``fps``) is 20 Hz; adjust it if your data
15+
differ.
16+
- NaNs are filled with the median of the trace prior to computation.
17+
18+
Example
19+
-------
20+
>>> import numpy as np
21+
>>> t = np.linspace(0, 10, 200, dtype=float) # 20 Hz sampling
22+
>>> y = 0.1 * np.sin(2 * np.pi * 1.0 * t) # small signal
23+
>>> snr, noise, peaks = estimate_snr(y) # doctest: +ELLIPSIS
24+
>>> isinstance(snr, float) and isinstance(noise, float)
25+
True
26+
>>> isinstance(peaks, np.ndarray)
27+
True
28+
>>> isinstance(estimate_kurtosis(y), float)
29+
True
30+
"""
31+
32+
from __future__ import annotations
33+
234
import warnings
35+
from typing import Tuple
36+
37+
import numpy as np
38+
from numpy.typing import NDArray
339
from scipy.signal import find_peaks
440
from scipy.stats import kurtosis
541

6-
# -snr function was inspired by AIND-OPhys SLAP2 team.
7-
# -for `trace` feed a dF/F preprocessed trace as the function calculate the peak values from zero
8-
# -so far all FIP sampling frequency has been 20Hz. Modify it when using a different fps.
9-
# -given the peak finding won't nan, nan will be filled with median of the trace.
42+
__all__ = ["estimate_snr", "estimate_kurtosis"]
1043

11-
def estimate_snr(trace, fps=20.0):
44+
45+
def estimate_snr(trace: NDArray[np.floating], fps: float = 20.0) -> Tuple[float, float, NDArray[np.intp]]:
1246
"""
13-
Estimate the signal-to-noise ratio (SNR) of a trace.
47+
Estimate the signal-to-noise ratio (SNR) of a 1D trace.
1448
1549
Parameters
1650
----------
17-
trace : np.ndarray
18-
The input trace.
19-
fps : float
20-
Frames per second of the trace.
51+
trace : numpy.ndarray
52+
1D input trace (e.g., dF/F). NaNs will be replaced with the
53+
median of ``trace`` before calculation.
54+
fps : float, optional
55+
Sampling frequency (frames per second), by default ``20.0``.
2156
2257
Returns
2358
-------
2459
snr : float
25-
Estimated signal-to-noise ratio.
60+
Estimated signal-to-noise ratio (dimensionless).
2661
noise : float
27-
Estimated noise level.
28-
peaks : np.ndarray
62+
Estimated noise level computed from the first difference of the trace
63+
(standard deviation of ``diff(trace)`` divided by ``sqrt(2)``).
64+
peaks : numpy.ndarray
2965
Indices of detected peaks in the trace.
66+
67+
Notes
68+
-----
69+
- Noise is estimated from the derivative assuming white noise.
70+
- Signal is estimated from the 95th percentile of peak amplitudes.
71+
- Peak detection uses ``scipy.signal.find_peaks`` with sensible defaults.
72+
- If fewer than three peaks are found, ``snr`` and ``peaks`` are set to
73+
``NaN`` and a :class:`warnings.WarningMessage` is issued.
3074
"""
3175
# Replace NaNs with the median of the trace
3276
trace = np.nan_to_num(trace, nan=np.nanmedian(trace))
3377

3478
# Noise estimation based on derivative, assuming random noise
3579
dfdt = np.diff(trace)
36-
noise = np.std(dfdt) / np.sqrt(2)
80+
noise = float(np.std(dfdt) / np.sqrt(2))
3781

38-
# Estimate signal as the third peak using scipy's find_peaks
82+
# Peak detection
3983
peaks, _ = find_peaks(
4084
trace,
41-
height=3 * noise, # Minimum peak height (adjust based on your signal scale)
42-
distance=fps * 0.1, # Minimum number of samples between peaks
43-
prominence=0.05, # How much a peak stands out relative to neighbors
44-
width=5 # Optional: minimum width of peak
85+
height=3 * noise, # Minimum peak height (adjust for your scale)
86+
distance=fps * 0.1, # Minimum number of samples between peaks
87+
prominence=0.05, # How much a peak stands out relative to neighbors
88+
width=5, # Optional: minimum width of peak
4589
)
4690

4791
if len(peaks) < 3:
48-
# Warning if not enough peaks are found
49-
warnings.warn("Not enough peaks found to estimate SNR. Returning NaN values.")
50-
return np.nan, noise, np.nan
51-
52-
# Take the 95th percentile of peak amplitudes as the signal
92+
warnings.warn(
93+
"Not enough peaks found to estimate SNR. Returning NaN values.",
94+
RuntimeWarning,
95+
stacklevel=2,
96+
)
97+
return float("nan"), noise, np.array(np.nan)
98+
99+
# Signal estimate: 95th percentile of detected peak amplitudes
53100
amplitudes = np.sort(trace[peaks])
54-
signal = np.percentile(amplitudes, 95)
101+
signal = float(np.percentile(amplitudes, 95))
55102

56103
# Calculate SNR
57-
snr = signal / noise
104+
snr = float(signal / noise) if noise > 0 else float("inf")
58105

59106
return snr, noise, peaks
60107

61108

62-
def estimate_kurtosis(trace):
109+
def estimate_kurtosis(trace: NDArray[np.floating]) -> float:
63110
"""
64-
Estimate the kurtosis of a trace distribution.
111+
Compute the **excess kurtosis** of a 1D trace distribution.
65112
66113
Parameters
67114
----------
68-
trace : np.ndarray
69-
The input trace.
115+
trace : numpy.ndarray
116+
1D input trace. NaNs will be replaced with the median of ``trace``.
70117
71118
Returns
72119
-------
73-
kurt : float
74-
Estimated excess kurtosis of the distribution.
75-
(Normal distribution = 0, leptokurtic > 0, platykurtic < 0)
120+
float
121+
Excess kurtosis of the distribution (Fisher definition):
122+
- Normal distribution → 0.0
123+
- Leptokurtic → positive
124+
- Platykurtic → negative
76125
"""
77126
# Replace NaNs with the median of the trace
78127
trace = np.nan_to_num(trace, nan=np.nanmedian(trace))
79128

80129
# Excess kurtosis (normal distribution = 0)
81-
kurt = kurtosis(trace, fisher=True, bias=False)
82-
83-
return kurt
130+
return float(kurtosis(trace, fisher=True, bias=False))

0 commit comments

Comments
 (0)