|
2 | 2 | # License: BSD-3-Clause |
3 | 3 | # Copyright the MNE-Python contributors. |
4 | 4 |
|
| 5 | + |
5 | 6 | import numpy as np |
6 | 7 | import pytest |
| 8 | +from numpy.testing import assert_allclose |
7 | 9 |
|
| 10 | +from mne import create_info |
8 | 11 | from mne.datasets import testing |
9 | 12 | from mne.datasets.testing import data_path |
10 | | -from mne.io import BaseRaw, read_raw_fif, read_raw_nirx, read_raw_snirf |
| 13 | +from mne.io import BaseRaw, RawArray, read_raw_fif, read_raw_nirx, read_raw_snirf |
11 | 14 | from mne.preprocessing.nirs import ( |
12 | 15 | _channel_frequencies, |
13 | 16 | beer_lambert_law, |
14 | 17 | optical_density, |
| 18 | + source_detector_distances, |
15 | 19 | ) |
| 20 | +from mne.preprocessing.nirs._beer_lambert_law import _get_sd_distances |
16 | 21 | from mne.utils import _validate_type |
17 | 22 |
|
18 | 23 | testing_path = data_path(download=False) |
@@ -112,3 +117,80 @@ def test_beer_lambert_v_matlab(): |
112 | 117 | + matlab_data["type"][idx] |
113 | 118 | ) |
114 | 119 | assert raw.info["ch_names"][idx] == matlab_name |
| 120 | + |
| 121 | + |
| 122 | +def test_beer_lambert_sd_distances(): |
| 123 | + """Test Beer-Lambert conversion with explicit source-detector distances.""" |
| 124 | + data = np.array( |
| 125 | + [[0.1, 0.2, 0.3], [0.15, 0.25, 0.35], [0.4, 0.5, 0.6], [0.45, 0.55, 0.65]] |
| 126 | + ) |
| 127 | + # Ch names chosen to test reordered indices |
| 128 | + ch_names = ["S1_D1 760", "S1_D1 850", "S10_D10 760", "S10_D10 850"] |
| 129 | + |
| 130 | + # Case 1: valid locations, sd_distances=None |
| 131 | + raw = RawArray(data, create_info(ch_names, sfreq=1.0, ch_types="fnirs_od")) |
| 132 | + sd_distances = [0.03, 0.03, 0.03, 0.03] |
| 133 | + for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], sd_distances)): |
| 134 | + raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0] |
| 135 | + raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0] |
| 136 | + raw.info["chs"][idx]["loc"][9] = freq |
| 137 | + expected = beer_lambert_law(raw) |
| 138 | + |
| 139 | + # Case 2: valid locations, sd_distances=<arr> |
| 140 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 141 | + actual = beer_lambert_law(raw, sd_distances=sd_distances) |
| 142 | + assert actual.ch_names == expected.ch_names |
| 143 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 144 | + |
| 145 | + # Case 3: no locations, sd_distances=None |
| 146 | + for idx in range(len(raw.info["chs"])): |
| 147 | + raw.info["chs"][idx]["loc"][3:9] = np.nan |
| 148 | + assert np.isnan(source_detector_distances(raw.info)).all() |
| 149 | + with pytest.raises( |
| 150 | + ValueError, match=r"(?i)source-detector distances are all zero or NaN" |
| 151 | + ): |
| 152 | + beer_lambert_law(raw) |
| 153 | + |
| 154 | + # Case 4: no locations, sd_distances=<arr> |
| 155 | + actual = beer_lambert_law(raw, sd_distances=sd_distances) |
| 156 | + assert actual.ch_names == expected.ch_names |
| 157 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 158 | + |
| 159 | + # Case 5: no locations, sd_distances=<scalar> |
| 160 | + actual = beer_lambert_law(raw, sd_distances=sd_distances[0]) |
| 161 | + assert actual.ch_names == expected.ch_names |
| 162 | + assert_allclose(actual.get_data(), expected.get_data(), rtol=1e-12, atol=0) |
| 163 | + |
| 164 | + |
| 165 | +def test_get_sd_distances(): |
| 166 | + """Test source-detector distance selection and validation.""" |
| 167 | + raw = RawArray( |
| 168 | + np.zeros((4, 3)), |
| 169 | + create_info( |
| 170 | + ["S1_D1 760", "S1_D1 850", "S2_D2 760", "S2_D2 850"], 1.0, "fnirs_od" |
| 171 | + ), |
| 172 | + ) |
| 173 | + expected = np.array([0.03, 0.03, 0.04, 0.04]) |
| 174 | + for idx, (freq, distance) in enumerate(zip([760, 850, 760, 850], expected)): |
| 175 | + raw.info["chs"][idx]["loc"][3:6] = [0.0, 0.0, 0.0] |
| 176 | + raw.info["chs"][idx]["loc"][6:9] = [distance, 0.0, 0.0] |
| 177 | + raw.info["chs"][idx]["loc"][9] = freq |
| 178 | + |
| 179 | + assert_allclose(_get_sd_distances(raw, None), expected, rtol=1e-12, atol=0) |
| 180 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 181 | + assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0) |
| 182 | + with pytest.warns(RuntimeWarning, match=r"(?i)will be overridden"): |
| 183 | + assert_allclose( |
| 184 | + _get_sd_distances(raw, 0.05), np.full(4, 0.05), rtol=1e-12, atol=0 |
| 185 | + ) |
| 186 | + |
| 187 | + for idx in range(len(raw.info["chs"])): |
| 188 | + raw.info["chs"][idx]["loc"][3:9] = np.nan |
| 189 | + assert_allclose(_get_sd_distances(raw, expected), expected, rtol=1e-12, atol=0) |
| 190 | + |
| 191 | + with pytest.raises(ValueError, match=r"1D array-like"): |
| 192 | + _get_sd_distances(raw, np.ones((2, 2))) |
| 193 | + with pytest.raises(ValueError, match=r"length matching"): |
| 194 | + _get_sd_distances(raw, [0.03, 0.03]) |
| 195 | + with pytest.raises(TypeError, match=r"sd_distances"): |
| 196 | + _get_sd_distances(raw, "foo") |
0 commit comments