Skip to content

Commit 535a4b3

Browse files
Fix WithinSession/WithinSubjectSplitter overwriting explicit n_splits in cv_kwargs (#1107)
* Initial plan * Fix cv_kwargs n_splits override in WithinSession/WithinSubject splitters * fix(evaluations): honour n_splits in Within* evaluations and make WithinSubjectSplitter reproducible - WithinSessionEvaluation/WithinSubjectEvaluation now map the base-class n_splits to the inner n_folds instead of hardcoding 5 folds, matching CrossSubjectEvaluation. - WithinSubjectSplitter.split() reseeds its RNG per call (shared across subjects to preserve the legacy fold sequence) so repeated calls with a fixed random_state are reproducible. - Correct the cv_kwargs docstrings/changelog: only n_splits can be passed through cv_kwargs (shuffle/random_state are named constructor params). * style: use dict literal in test_within_n_splits_drives_n_folds (ruff C408) --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Bru <b.aristimunha@gmail.com>
1 parent 24800f9 commit 535a4b3

5 files changed

Lines changed: 79 additions & 6 deletions

File tree

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ Bugs
7171
- Fix Windows download path sanitization that changed absolute paths like ``C:\data`` into relative ``C-\data`` paths (:gh:`1079` by `Anton Andreev`_).
7272
- Fix missing electrode positions (NaN xyz) in six motor-imagery datasets so topographic maps, interpolation, and spatial methods work: :class:`moabb.datasets.Forenzo2023` and :class:`moabb.datasets.GuttmannFlury2025_MI`/``_ME`` normalize Neuroscan ALL_CAPS labels and apply ``standard_1005`` (CB1/CB2 kept as ``misc``); :class:`moabb.datasets.Dreyer2023` falls back to ``standard_1005`` when the BIDS archive ships no ``electrodes.tsv``; :class:`moabb.datasets.BNCI2003_004` maps its 26 legacy Berlin channel labels to their modern 10-5 equivalents for exact positions; :class:`moabb.datasets.BNCI2014_002` applies an approximate 3x5 grid for its unlabeled small-Laplacian channels; and :class:`moabb.datasets.Zhang2017` applies the ``GSN-HydroCel-32`` montage in EGI sensor order. Adds the shared :func:`moabb.datasets.utils.set_neuroscan_montage` helper (:gh:`1089` by `Bruno Aristimunha`_).
7373
- Fix ``BaseEvaluation._aggregate_fold_results`` aborting the whole evaluation with ``TypeError: agg function failed [how->mean,dtype->object]`` when a single fold contributes a non-numeric ``score`` (e.g. an error fold). The numeric aggregation columns are now coerced with ``pandas.to_numeric(errors="coerce")`` before ``groupby.agg``, so a bad fold becomes ``NaN`` and is skipped instead of taking down every subject/pipeline (:gh:`1095` by `Bruno Aristimunha`_).
74+
- Fix :class:`moabb.evaluations.splitters.WithinSessionSplitter` and :class:`moabb.evaluations.splitters.WithinSubjectSplitter` overwriting an explicit ``n_splits`` passed through ``cv_kwargs`` with the ``n_folds`` default; the caller-provided ``n_splits`` now takes precedence, so a single holdout split can be requested directly via ``cv_class=StratifiedShuffleSplit, n_splits=1``. :class:`moabb.evaluations.WithinSessionEvaluation` and :class:`moabb.evaluations.WithinSubjectEvaluation` now honour the ``n_splits`` argument instead of always running 5 folds, and :class:`moabb.evaluations.splitters.WithinSubjectSplitter` now yields reproducible per-subject folds for a fixed ``random_state`` (:gh:`1106` by `Bruno Aristimunha`_).
7475
Code health
7576
~~~~~~~~~~~
7677
- Install CPU-only PyTorch wheels in CI by setting ``UV_TORCH_BACKEND=cpu`` in the test, braindecode, and docs workflows, so runners no longer download multi-GB CUDA builds of ``torch`` (pulled transitively via the ``deeplearning`` extra / braindecode) (:gh:`1083` by `Bhargav Kowshik`_).

moabb/evaluations/evaluations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _create_splitter(self):
8484
"""Create the WithinSessionSplitter for parallel evaluation."""
8585
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
8686
return WithinSessionSplitter(
87-
n_folds=5,
87+
n_folds=self.n_splits or 5,
8888
shuffle=True,
8989
random_state=self.random_state,
9090
cv_class=cv_class,
@@ -658,7 +658,7 @@ def _create_splitter(self):
658658
"""Create the WithinSubjectSplitter for parallel evaluation."""
659659
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
660660
return WithinSubjectSplitter(
661-
n_folds=5,
661+
n_folds=self.n_splits or 5,
662662
shuffle=True,
663663
random_state=self.random_state,
664664
cv_class=cv_class,

moabb/evaluations/splitters.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class WithinSessionSplitter(BaseCrossValidator):
5959
Defaults to ``StratifiedKFold``.
6060
cv_kwargs : dict
6161
Additional arguments to pass to the inner cross-validation strategy.
62+
An explicit ``n_splits`` provided here takes precedence over the
63+
``n_folds`` argument.
6264
6365
"""
6466

@@ -95,7 +97,7 @@ def __init__(
9597
("shuffle", shuffle),
9698
("random_state", self._rng),
9799
]:
98-
if p in params:
100+
if p in params and p not in cv_kwargs:
99101
self._cv_kwargs[p] = v
100102
self._last_split_metadata = None
101103

@@ -199,6 +201,8 @@ class WithinSubjectSplitter(BaseCrossValidator):
199201
Defaults to ``StratifiedKFold``.
200202
cv_kwargs : dict
201203
Additional arguments to pass to the inner cross-validation strategy.
204+
An explicit ``n_splits`` provided here takes precedence over the
205+
``n_folds`` argument.
202206
203207
"""
204208

@@ -235,7 +239,7 @@ def __init__(
235239
("shuffle", shuffle),
236240
("random_state", self._rng),
237241
]:
238-
if p in params:
242+
if p in params and p not in cv_kwargs:
239243
self._cv_kwargs[p] = v
240244
self._last_split_metadata = None
241245

@@ -280,16 +284,26 @@ def split(self, y, metadata):
280284
# Shuffle subjects if required
281285
# Convert to numpy array to avoid ArrowStringArray shuffle warning
282286
subjects = np.array(metadata["subject"].unique())
287+
# Reseed the RNG at each split() call so repeated calls are
288+
# reproducible. A single RNG is shared across subjects (instead of a
289+
# fresh per-subject one) to keep the fold sequence identical to the
290+
# legacy within-subject behaviour.
291+
rng = check_random_state(self.random_state) if self.shuffle else None
283292
if self.shuffle:
284-
self._rng.shuffle(subjects)
293+
rng.shuffle(subjects)
294+
295+
cv_kwargs = dict(self._cv_kwargs)
296+
params = inspect.signature(self.cv_class).parameters
297+
if self.shuffle and "random_state" in params:
298+
cv_kwargs["random_state"] = rng
285299

286300
for subject in subjects:
287301
subject_mask = metadata["subject"] == subject
288302
subject_indices = all_index[subject_mask]
289303
y_subject = y[subject_mask]
290304

291305
# Instantiate a new internal splitter for each subject
292-
splitter = self.cv_class(**self._cv_kwargs)
306+
splitter = self.cv_class(**cv_kwargs)
293307

294308
# Split using the cross-validation strategy across all sessions of the subject
295309
for train_ix, test_ix in splitter.split(subject_indices, y_subject):

moabb/tests/test_evaluations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,26 @@ def setup_method(self):
435435
)
436436

437437

438+
@pytest.mark.parametrize(
439+
"klass", [ev.WithinSessionEvaluation, ev.WithinSubjectEvaluation]
440+
)
441+
def test_within_n_splits_drives_n_folds(klass):
442+
"""n_splits sets the inner splitter's n_folds (defaults to 5 when unset)."""
443+
kw = {
444+
"paradigm": FakeImageryParadigm(),
445+
"datasets": [dataset],
446+
"hdf5_path": "res_test",
447+
}
448+
evals = {None: klass(**kw), 3: klass(n_splits=3, **kw)}
449+
try:
450+
for n, e in evals.items():
451+
assert e._create_splitter().n_folds == (n or 5)
452+
finally:
453+
for e in evals.values():
454+
if os.path.isfile(e.results.filepath):
455+
os.remove(e.results.filepath)
456+
457+
438458
class Test_CrossSubj(TestWithinSess):
439459
def setup_method(self):
440460
self.eval = ev.CrossSubjectEvaluation(

moabb/tests/test_splits.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,44 @@ def test_within_subject_get_n_splits(data):
487487
assert n_splits == 5 * 5 # 5 subjects, 5 folds each
488488

489489

490+
@pytest.mark.parametrize("splitter", [WithinSessionSplitter, WithinSubjectSplitter])
491+
def test_cv_kwargs_n_splits_not_overwritten(data, splitter):
492+
"""Explicit n_splits in cv_kwargs must not be overwritten by n_folds."""
493+
_, y, metadata = data
494+
495+
split = splitter(
496+
cv_class=StratifiedShuffleSplit,
497+
n_splits=1,
498+
test_size=0.25,
499+
shuffle=True,
500+
random_state=42,
501+
)
502+
503+
# The inner cv should keep the explicitly requested single split.
504+
assert split._cv_kwargs["n_splits"] == 1
505+
506+
if splitter == WithinSessionSplitter:
507+
num_groups = metadata.groupby(["subject", "session"]).ngroups
508+
else:
509+
num_groups = metadata["subject"].nunique()
510+
511+
splits = list(split.split(y, metadata))
512+
assert len(splits) == num_groups # one split per group, not n_folds per group
513+
514+
515+
@pytest.mark.parametrize("splitter", [WithinSessionSplitter, WithinSubjectSplitter])
516+
def test_within_split_is_reproducible(data, splitter):
517+
"""Repeated split() calls with a fixed seed must yield identical folds."""
518+
_, y, metadata = data
519+
split = splitter(shuffle=True, random_state=42)
520+
first = list(split.split(y, metadata))
521+
second = list(split.split(y, metadata))
522+
assert len(first) == len(second)
523+
for (train, test), (train_2, test_2) in zip(first, second):
524+
assert np.array_equal(train, train_2)
525+
assert np.array_equal(test, test_2)
526+
527+
490528
@pytest.mark.parametrize(
491529
"splitter", [CrossSessionSplitter, CrossSubjectSplitter, CrossDatasetSplitter]
492530
)

0 commit comments

Comments
 (0)