Skip to content

Commit 9438956

Browse files
authored
Fix picks being ignored in Epochs.apply_function() (#13894)
1 parent 893b784 commit 9438956

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

doc/changes/dev/13894.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where ``picks`` was ignored in :meth:`mne.Epochs.apply_function` when ``channel_wise=False``, by `Thomas Binns`_.

mne/epochs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2045,7 +2045,7 @@ def apply_function(
20452045
for run_idx, ch_idx in enumerate(picks):
20462046
self._data[:, ch_idx, :] = data_picks_new[run_idx]
20472047
else:
2048-
self._data = _check_fun(fun, data_in, **kwargs)
2048+
self._data[:, picks, :] = _check_fun(fun, data_in[:, picks, :], **kwargs)
20492049

20502050
return self
20512051

mne/tests/test_epochs.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4832,20 +4832,21 @@ def test_apply_function():
48324832
info = mne.create_info(n_channels, 1000.0, "eeg")
48334833
epochs = mne.EpochsArray(data, info, events)
48344834
data_epochs = epochs.get_data()
4835+
picks = np.arange(3)
4836+
non_picks = np.arange(3, n_channels)
48354837

48364838
# apply_function to all channels at once
48374839
def fun(data):
48384840
"""Reverse channel order without changing values."""
48394841
return np.eye(data.shape[1])[::-1] @ data
48404842

4841-
want = data_epochs[:, ::-1]
4842-
got = epochs.apply_function(fun, channel_wise=False).get_data()
4843+
want = np.concatenate(
4844+
[data_epochs[:, picks][:, ::-1], data_epochs[:, non_picks]], axis=1
4845+
) # only reverse channel order of picks
4846+
got = epochs.apply_function(fun, picks=picks, channel_wise=False).get_data()
48434847
assert_array_equal(want, got)
48444848

48454849
# apply_function channel-wise (to first 3 channels) by replacing with mean
4846-
picks = np.arange(3)
4847-
non_picks = np.arange(3, n_channels)
4848-
48494850
def fun(data):
48504851
return np.full_like(data, data.mean())
48514852

0 commit comments

Comments
 (0)