Skip to content

Commit b2ec20e

Browse files
committed
address reviews
1 parent e2cfc81 commit b2ec20e

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

mne/epochs.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4706,11 +4706,25 @@ def _concatenate_epochs_spectrum(epochs_list, add_offset=True):
47064706
if not np.array_equal(ep.freqs, ref.freqs):
47074707
raise ValueError(f"epochs_list[{ii}] freqs do not match epochs_list[0]")
47084708
_ensure_infos_match(ep.info, ref.info, f"epochs_list[{ii}]")
4709+
if ep.method != ref.method:
4710+
raise ValueError(
4711+
f"epochs_list[{ii}] method {ep.method!r} does not match "
4712+
f"epochs_list[0] method {ref.method!r}"
4713+
)
4714+
if ref.method == "multitaper":
4715+
ref_weights = getattr(ref, "weights", None)
4716+
ep_weights = getattr(ep, "weights", None)
4717+
if ref_weights is not None and ep_weights is not None:
4718+
if not np.array_equal(ep_weights, ref_weights):
4719+
raise ValueError(
4720+
f"epochs_list[{ii}] multitaper weights do not match "
4721+
f"epochs_list[0]"
4722+
)
47094723

47104724
data = np.concatenate([ep.data for ep in epochs_list], axis=0)
47114725

4712-
shift = np.int64(10 * ref.info["sfreq"])
4713-
events_offset = int(np.max(epochs_list[0].events[:, 0])) + shift
4726+
shift = len(ref.freqs)
4727+
events_offset = ref.events[-1, 0] + shift
47144728
all_events = [epochs_list[0].events.copy()]
47154729
for ep in epochs_list[1:]:
47164730
evs = ep.events.copy()

mne/time_frequency/tests/test_spectrum.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,3 +802,16 @@ def test_concatenate_epochs_spectrum():
802802
# passing a non-EpochsSpectrum should raise
803803
with pytest.raises(TypeError, match="must be an instance of EpochsSpectrum"):
804804
concatenate_epochs([sp1, epochs[:10]])
805+
806+
# mismatched method should raise
807+
sp_welch = epochs[:10].compute_psd(method="welch")
808+
with pytest.raises(ValueError, match="method"):
809+
concatenate_epochs([sp1, sp_welch])
810+
811+
# mismatched multitaper weights should raise
812+
sp_mt1 = epochs[:10].compute_psd(method="multitaper", bandwidth=2, output="complex")
813+
sp_mt2 = epochs[10:20].compute_psd(
814+
method="multitaper", bandwidth=20, output="complex"
815+
)
816+
with pytest.raises(ValueError, match="weights"):
817+
concatenate_epochs([sp_mt1, sp_mt2])

0 commit comments

Comments
 (0)