Skip to content

Commit 58c8ab2

Browse files
committed
Add sd_distances=None to MBLL for manual override of source detector locations
1 parent 423f8d5 commit 58c8ab2

2 files changed

Lines changed: 137 additions & 4 deletions

File tree

mne/preprocessing/nirs/_beer_lambert_law.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@
1111
from ..._fiff.constants import FIFF
1212
from ...io import BaseRaw
1313
from ...utils import _validate_type, pinv, warn
14-
from ..nirs import _channel_frequencies, _validate_nirs_info, source_detector_distances
14+
from ..nirs import (
15+
_channel_frequencies,
16+
_has_source_detector_distances,
17+
_validate_nirs_info,
18+
source_detector_distances,
19+
)
1520

1621

17-
def beer_lambert_law(raw, ppf=6.0):
22+
def beer_lambert_law(raw, ppf=6.0, sd_distances=None):
1823
r"""Convert NIRS optical density data to haemoglobin concentration.
1924
2025
Parameters
@@ -26,6 +31,10 @@ def beer_lambert_law(raw, ppf=6.0):
2631
2732
.. versionchanged:: 1.7
2833
Support for different factors for the two wavelengths.
34+
sd_distances : array-like | float | None
35+
Source-detector distances in meters. If ``None``, distances are read
36+
from ``raw.info['chs']``. If array-like, the values must have a distance
37+
for each channel, matching the order in ``info['chs']``.
2938
3039
Returns
3140
-------
@@ -70,9 +79,14 @@ def beer_lambert_law(raw, ppf=6.0):
7079
)
7180

7281
abs_coef = _load_absorption(unique_freqs) # shape (n_wavelengths, 2)
73-
distances = source_detector_distances(raw.info, picks="all")
82+
distances = _get_sd_distances(raw, sd_distances)
7483
bad = ~np.isfinite(distances[picks])
7584
bad |= distances[picks] <= 0
85+
if bad.all():
86+
raise ValueError(
87+
"Source-detector distances are all zero or NaN. Consider setting a "
88+
"montage with raw.set_montage or providing sd_distances."
89+
)
7690
if bad.any():
7791
warn(
7892
"Source-detector distances are zero or NaN, some resulting "
@@ -129,6 +143,39 @@ def beer_lambert_law(raw, ppf=6.0):
129143
return raw
130144

131145

146+
def _get_sd_distances(raw, sd_distances):
147+
"""Get source-detector distances for each channel.
148+
149+
Returns
150+
-------
151+
dists : array of float
152+
Array containing distances in meters.
153+
Of shape equal to number of channels.
154+
"""
155+
if sd_distances is None:
156+
# picks="all" used here instead of picks s.t. distance indices match raw
157+
return source_detector_distances(raw.info, picks="all")
158+
elif _has_source_detector_distances(raw.info, picks="all"):
159+
warn("Source-detector distances in raw.info[] will be overridden")
160+
_validate_type(sd_distances, ("numeric", "array-like"), "sd_distances")
161+
sd_distances = np.array(sd_distances, float)
162+
n_channels = len(raw.info["chs"])
163+
if sd_distances.ndim == 0:
164+
return np.full(n_channels, sd_distances)
165+
if sd_distances.ndim != 1:
166+
raise ValueError(
167+
"sd_distances must be a float or a 1D array-like, got "
168+
f"shape {sd_distances.shape}"
169+
)
170+
if len(sd_distances) == n_channels:
171+
return sd_distances
172+
raise ValueError(
173+
"sd_distances must be a float or an array-like with length matching "
174+
f"the len(raw.info['chs']) ({n_channels}), "
175+
f"got length {len(sd_distances)}"
176+
)
177+
178+
132179
def _load_absorption(freqs):
133180
"""Load molar extinction coefficients."""
134181
# Data from https://omlc.org/spectra/hemoglobin/summary.html

mne/preprocessing/nirs/tests/test_beer_lambert_law.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,23 @@
22
# License: BSD-3-Clause
33
# Copyright the MNE-Python contributors.
44

5+
import warnings
6+
57
import numpy as np
68
import pytest
9+
from numpy.testing import assert_allclose
710

11+
from mne import create_info
812
from mne.datasets import testing
913
from mne.datasets.testing import data_path
10-
from mne.io import BaseRaw, read_raw_fif, read_raw_nirx, read_raw_snirf
14+
from mne.io import BaseRaw, RawArray, read_raw_fif, read_raw_nirx, read_raw_snirf
1115
from mne.preprocessing.nirs import (
1216
_channel_frequencies,
1317
beer_lambert_law,
1418
optical_density,
19+
source_detector_distances,
1520
)
21+
from mne.preprocessing.nirs._beer_lambert_law import _get_sd_distances
1622
from mne.utils import _validate_type
1723

1824
testing_path = data_path(download=False)
@@ -112,3 +118,83 @@ def test_beer_lambert_v_matlab():
112118
+ matlab_data["type"][idx]
113119
)
114120
assert raw.info["ch_names"][idx] == matlab_name
121+
122+
123+
def test_beer_lambert_sd_distances():
124+
"""Test Beer-Lambert conversion with explicit source-detector distances."""
125+
data = np.array(
126+
[[0.1, 0.2, 0.3], [0.15, 0.25, 0.35], [0.4, 0.5, 0.6], [0.45, 0.55, 0.65]]
127+
)
128+
# Ch names chosen to test reordered indices
129+
ch_names = ["S1_D1 760", "S1_D1 850", "S10_D10 760", "S10_D10 850"]
130+
131+
# Case 1: valid locations, sd_distances=None
132+
raw = RawArray(data, create_info(ch_names, sfreq=1.0, ch_types="fnirs_od"))
133+
sd_distances = [0.03, 0.03, 0.03, 0.03]
134+
for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], sd_distances)):
135+
raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0]
136+
raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0]
137+
raw.info["chs"][idx]["loc"][9] = freq
138+
expected = beer_lambert_law(raw)
139+
140+
# Case 2: valid locations, sd_distances=<arr>
141+
with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"):
142+
actual = beer_lambert_law(raw, sd_distances=sd_distances)
143+
assert actual.ch_names == expected.ch_names
144+
assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0)
145+
146+
# Case 3: no locations, sd_distances=None
147+
for idx in range(len(raw.info["chs"])):
148+
raw.info["chs"][idx]["loc"][3:9] = np.nan
149+
assert np.isnan(source_detector_distances(raw.info)).all()
150+
with pytest.raises(
151+
ValueError, match=r"(?i)source-detector distances are all zero or NaN"
152+
):
153+
beer_lambert_law(raw)
154+
155+
# Case 4: no locations, sd_distances=<arr>
156+
actual = beer_lambert_law(raw, sd_distances=sd_distances)
157+
assert actual.ch_names == expected.ch_names
158+
assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0)
159+
160+
# Case 5: no locations, sd_distances=<scalar>
161+
actual = beer_lambert_law(raw, sd_distances=sd_distances[0])
162+
assert actual.ch_names == expected.ch_names
163+
assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0)
164+
165+
166+
def test_get_sd_distances():
167+
"""Test source-detector distance selection and validation."""
168+
raw = RawArray(
169+
np.zeros((4, 3)),
170+
create_info(
171+
["S1_D1 760", "S1_D1 850", "S2_D2 760", "S2_D2 850"], 1.0, "fnirs_od"
172+
),
173+
)
174+
expected = np.array([0.03, 0.03, 0.04, 0.04])
175+
for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], expected)):
176+
raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0]
177+
raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0]
178+
raw.info["chs"][idx]["loc"][9] = freq
179+
180+
assert_allclose(_get_sd_distances(raw, None), expected, rtol=1e-12, atol=0)
181+
with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"):
182+
assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0)
183+
with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"):
184+
assert_allclose(
185+
_get_sd_distances(raw, 0.05), np.full(4, 0.05), rtol=1e-12, atol=0
186+
)
187+
188+
for idx in range(len(raw.info["chs"])):
189+
raw.info["chs"][idx]["loc"][3:9] = np.nan
190+
with warnings.catch_warnings(record=True) as caught:
191+
warnings.simplefilter("always")
192+
assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0)
193+
assert len(caught) == 0
194+
195+
with pytest.raises(ValueError, match=r"1D array-like"):
196+
_get_sd_distances(raw, np.ones((2, 2)))
197+
with pytest.raises(ValueError, match=r"length matching"):
198+
_get_sd_distances(raw, [0.03, 0.03])
199+
with pytest.raises(TypeError, match=r"sd_distances"):
200+
_get_sd_distances(raw, "foo")

0 commit comments

Comments
 (0)