|
10 | 10 | from numpy.testing import assert_allclose, assert_array_equal |
11 | 11 |
|
12 | 12 | import mne.channels.channels |
13 | | -from mne import Epochs, pick_channels, pick_types, read_events |
| 13 | +from mne import Epochs, create_info, pick_channels, pick_types, read_events |
14 | 14 | from mne._fiff.constants import FIFF |
15 | 15 | from mne._fiff.proj import _has_eeg_average_ref_proj |
16 | 16 | from mne.channels import make_dig_montage, make_standard_montage |
@@ -333,6 +333,44 @@ def test_interpolation_nirs(): |
333 | 333 | assert raw_haemo.info["bads"] == [] |
334 | 334 |
|
335 | 335 |
|
| 336 | +def test_interpolation_nirs_reordered_picks(): |
| 337 | + """Test NIRS interpolation uses the closest donor in raw channel space.""" |
| 338 | + ch_names = [ |
| 339 | + "S1_D1 760", |
| 340 | + "S1_D1 850", |
| 341 | + "S2_D2 760", |
| 342 | + "S2_D2 850", |
| 343 | + "S3_D3 760", |
| 344 | + "S3_D3 850", |
| 345 | + "S10_D10 760", |
| 346 | + "S10_D10 850", |
| 347 | + ] |
| 348 | + info = create_info(ch_names, sfreq=1.0, ch_types=["fnirs_cw_amplitude"] * 8) |
| 349 | + pair_positions = { |
| 350 | + "S1_D1": (0.009, 0.0, 0.0), |
| 351 | + "S2_D2": (0.010, 0.0, 0.0), |
| 352 | + "S3_D3": (0.030, 0.0, 0.0), |
| 353 | + "S10_D10": (0.040, 0.0, 0.0), |
| 354 | + } |
| 355 | + for idx, ch in enumerate(info["chs"]): |
| 356 | + pair = ch["ch_name"].rsplit(" ", 1)[0] |
| 357 | + ch["loc"][:3] = pair_positions[pair] |
| 358 | + ch["loc"][9] = 760.0 if idx % 2 == 0 else 850.0 |
| 359 | + data = np.arange(len(ch_names), dtype=float).reshape(-1, 1) |
| 360 | + data = np.repeat(data, 5, axis=1) |
| 361 | + raw = RawArray(data, info, verbose=False) |
| 362 | + raw.info["bads"] = ["S2_D2 760", "S2_D2 850"] |
| 363 | + |
| 364 | + raw.interpolate_bads( |
| 365 | + method=dict(fnirs="nearest"), origin=(0.0, 0.0, 0.0), verbose=False |
| 366 | + ) |
| 367 | + |
| 368 | + # Bad S2_D2 should copy from the nearest good pair, S1_D1. |
| 369 | + picks_bad = pick_channels(raw.ch_names, ["S2_D2 760", "S2_D2 850"], exclude=[]) |
| 370 | + picks_want = pick_channels(raw.ch_names, ["S1_D1 760", "S1_D1 850"], exclude=[]) |
| 371 | + assert_allclose(raw.get_data(picks=picks_bad), raw.get_data(picks=picks_want)) |
| 372 | + |
| 373 | + |
336 | 374 | @testing.requires_testing_data |
337 | 375 | def test_interpolation_ecog(): |
338 | 376 | """Test interpolation for ECoG.""" |
|
0 commit comments