|
2 | 2 | # License: BSD-3-Clause |
3 | 3 | # Copyright the MNE-Python contributors. |
4 | 4 |
|
| 5 | +import warnings |
| 6 | + |
5 | 7 | import numpy as np |
6 | 8 | import pytest |
| 9 | +from numpy.testing import assert_allclose |
7 | 10 |
|
| 11 | +from mne import create_info |
8 | 12 | from mne.datasets import testing |
9 | 13 | from mne.datasets.testing import data_path |
10 | | -from mne.io import BaseRaw, read_raw_fif, read_raw_nirx, read_raw_snirf |
| 14 | +from mne.io import BaseRaw, RawArray, read_raw_fif, read_raw_nirx, read_raw_snirf |
11 | 15 | from mne.preprocessing.nirs import ( |
12 | 16 | _channel_frequencies, |
13 | 17 | beer_lambert_law, |
14 | 18 | optical_density, |
| 19 | + source_detector_distances, |
15 | 20 | ) |
| 21 | +from mne.preprocessing.nirs._beer_lambert_law import _get_sd_distances |
16 | 22 | from mne.utils import _validate_type |
17 | 23 |
|
18 | 24 | testing_path = data_path(download=False) |
@@ -112,3 +118,83 @@ def test_beer_lambert_v_matlab(): |
112 | 118 | + matlab_data["type"][idx] |
113 | 119 | ) |
114 | 120 | assert raw.info["ch_names"][idx] == matlab_name |
| 121 | + |
| 122 | + |
| 123 | +def test_beer_lambert_sd_distances(): |
| 124 | + """Test Beer-Lambert conversion with explicit source-detector distances.""" |
| 125 | + data = np.array( |
| 126 | + [[0.1, 0.2, 0.3], [0.15, 0.25, 0.35], [0.4, 0.5, 0.6], [0.45, 0.55, 0.65]] |
| 127 | + ) |
| 128 | + # Ch names chosen to test reordered indices |
| 129 | + ch_names = ["S1_D1 760", "S1_D1 850", "S10_D10 760", "S10_D10 850"] |
| 130 | + |
| 131 | + # Case 1: valid locations, sd_distances=None |
| 132 | + raw = RawArray(data, create_info(ch_names, sfreq=1.0, ch_types="fnirs_od")) |
| 133 | + sd_distances = [0.03, 0.03, 0.03, 0.03] |
| 134 | + for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], sd_distances)): |
| 135 | + raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0] |
| 136 | + raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0] |
| 137 | + raw.info["chs"][idx]["loc"][9] = freq |
| 138 | + expected = beer_lambert_law(raw) |
| 139 | + |
| 140 | + # Case 2: valid locations, sd_distances=<arr> |
| 141 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 142 | + actual = beer_lambert_law(raw, sd_distances=sd_distances) |
| 143 | + assert actual.ch_names == expected.ch_names |
| 144 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 145 | + |
| 146 | + # Case 3: no locations, sd_distances=None |
| 147 | + for idx in range(len(raw.info["chs"])): |
| 148 | + raw.info["chs"][idx]["loc"][3:9] = np.nan |
| 149 | + assert np.isnan(source_detector_distances(raw.info)).all() |
| 150 | + with pytest.raises( |
| 151 | + ValueError, match=r"(?i)source-detector distances are all zero or NaN" |
| 152 | + ): |
| 153 | + beer_lambert_law(raw) |
| 154 | + |
| 155 | + # Case 4: no locations, sd_distances=<arr> |
| 156 | + actual = beer_lambert_law(raw, sd_distances=sd_distances) |
| 157 | + assert actual.ch_names == expected.ch_names |
| 158 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 159 | + |
| 160 | + # Case 5: no locations, sd_distances=<scalar> |
| 161 | + actual = beer_lambert_law(raw, sd_distances=sd_distances[0]) |
| 162 | + assert actual.ch_names == expected.ch_names |
| 163 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 164 | + |
| 165 | + |
| 166 | +def test_get_sd_distances(): |
| 167 | + """Test source-detector distance selection and validation.""" |
| 168 | + raw = RawArray( |
| 169 | + np.zeros((4, 3)), |
| 170 | + create_info( |
| 171 | + ["S1_D1 760", "S1_D1 850", "S2_D2 760", "S2_D2 850"], 1.0, "fnirs_od" |
| 172 | + ), |
| 173 | + ) |
| 174 | + expected = np.array([0.03, 0.03, 0.04, 0.04]) |
| 175 | + for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], expected)): |
| 176 | + raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0] |
| 177 | + raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0] |
| 178 | + raw.info["chs"][idx]["loc"][9] = freq |
| 179 | + |
| 180 | + assert_allclose(_get_sd_distances(raw, None), expected, rtol=1e-12, atol=0) |
| 181 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 182 | + assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0) |
| 183 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 184 | + assert_allclose( |
| 185 | + _get_sd_distances(raw, 0.05), np.full(4, 0.05), rtol=1e-12, atol=0 |
| 186 | + ) |
| 187 | + |
| 188 | + for idx in range(len(raw.info["chs"])): |
| 189 | + raw.info["chs"][idx]["loc"][3:9] = np.nan |
| 190 | + with warnings.catch_warnings(record=True) as caught: |
| 191 | + warnings.simplefilter("always") |
| 192 | + assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0) |
| 193 | + assert len(caught) == 0 |
| 194 | + |
| 195 | + with pytest.raises(ValueError, match=r"1D array-like"): |
| 196 | + _get_sd_distances(raw, np.ones((2, 2))) |
| 197 | + with pytest.raises(ValueError, match=r"length matching"): |
| 198 | + _get_sd_distances(raw, [0.03, 0.03]) |
| 199 | + with pytest.raises(TypeError, match=r"sd_distances"): |
| 200 | + _get_sd_distances(raw, "foo") |
0 commit comments