diff --git a/doc/changes/dev/13894.bugfix.rst b/doc/changes/dev/13894.bugfix.rst new file mode 100644 index 00000000000..c6af5316f06 --- /dev/null +++ b/doc/changes/dev/13894.bugfix.rst @@ -0,0 +1 @@ +Fix bug where ``picks`` was ignored in :meth:`mne.Epochs.apply_function` when ``channel_wise=False``, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/epochs.py b/mne/epochs.py index 03fff949572..8aaf1316305 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -2045,7 +2045,7 @@ def apply_function( for run_idx, ch_idx in enumerate(picks): self._data[:, ch_idx, :] = data_picks_new[run_idx] else: - self._data = _check_fun(fun, data_in, **kwargs) + self._data[:, picks, :] = _check_fun(fun, data_in[:, picks, :], **kwargs) return self diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 91c5f902ac8..1a66dc03eec 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -4832,20 +4832,21 @@ def test_apply_function(): info = mne.create_info(n_channels, 1000.0, "eeg") epochs = mne.EpochsArray(data, info, events) data_epochs = epochs.get_data() + picks = np.arange(3) + non_picks = np.arange(3, n_channels) # apply_function to all channels at once def fun(data): """Reverse channel order without changing values.""" return np.eye(data.shape[1])[::-1] @ data - want = data_epochs[:, ::-1] - got = epochs.apply_function(fun, channel_wise=False).get_data() + want = np.concatenate( + [data_epochs[:, picks][:, ::-1], data_epochs[:, non_picks]], axis=1 + ) # only reverse channel order of picks + got = epochs.apply_function(fun, picks=picks, channel_wise=False).get_data() assert_array_equal(want, got) # apply_function channel-wise (to first 3 channels) by replacing with mean - picks = np.arange(3) - non_picks = np.arange(3, n_channels) - def fun(data): return np.full_like(data, data.mean())