Skip to content

Commit b9d01a0

Browse files
Drive splitters with stock scikit-learn CVs via groups= argument (#1105)
Co-authored-by: Bru <b.aristimunha@gmail.com>
1 parent 535a4b3 commit b9d01a0

5 files changed

Lines changed: 336 additions & 35 deletions

File tree

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Enhancements
3939
- Re-enable auto-execution of the Riemannian Artifact Rejection tutorial (``examples/advanced_examples/plot_riemannian_artifact_rejection.py``) now that pyRiemann 0.11 is on PyPI with per-potato metrics and ``method_combination`` support on ``PotatoField`` (by `Bruno Aristimunha`_)
4040
- Use NEMAR as the default download source for datasets with an assigned ``nemar_id``, while preserving existing dataset-specific downloaders as a fallback (by `Bruno Aristimunha`_).
4141
- Add :class:`moabb.datasets.preprocessing.EuclideanAlignment`, a trial-level Euclidean Alignment transformer (He & Wu 2020; Junqueira et al. 2024) that whitens each trial by the inverse square root of the Euclidean mean covariance to remove per-domain covariance shift before a (deep) model sees the data. Inductive and leakage-free by default (``fit`` learns the reference from training trials, ``transform`` re-applies it to unseen trials); ``fit_transform`` gives the transductive, per-recording form. Accepts an :class:`mne.BaseEpochs` or an ``(n_trials, n_channels, n_times)`` ndarray, uses a shrinkage covariance estimator (``"lwf"``) for robustness, and adds no new dependency (``pyriemann >= 0.11`` is already required). Distinct from :class:`pyriemann.transfer.TLCenter`, which recenters covariance *matrices* (:gh:`1108` by `Bruno Aristimunha`_).
42+
- Drive cross-validation folds with any stock scikit-learn cross-validator passed as ``cv_class``, controlled by a ``groups`` argument — a metadata column name, a list of column names (compound key, e.g. ``["subject", "session"]``), or a callable ``metadata -> array`` — together with callable ``cv_kwargs`` resolved against the metadata (e.g. ``cv_class=PredefinedSplit`` with a ``test_fold`` callable to target a single fold). ``groups`` is exposed on :class:`moabb.evaluations.WithinSessionEvaluation`, :class:`moabb.evaluations.WithinSubjectEvaluation`, :class:`moabb.evaluations.CrossSessionEvaluation` and :class:`moabb.evaluations.CrossSubjectEvaluation` and threaded to their splitters; each splitter keeps its default grouping (``"subject"`` / ``"session"`` / labels) when ``groups`` is ``None``. :class:`moabb.evaluations.splitters.CrossDatasetSplitter` gains ``groups`` (its ``group_column`` argument is now a deprecated alias) (:gh:`1104` by `Bruno Aristimunha`_).
4243

4344
API changes
4445
~~~~~~~~~~~

moabb/evaluations/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,13 @@ class BaseEvaluation(ABC):
301301
cv_kwargs : dict or None
302302
Keyword arguments passed to cv_class when constructing the splitter.
303303
Defaults to ``None``.
304+
groups : str, list of str, callable, or None
305+
What defines the cross-validation folds, forwarded to the evaluation's
306+
splitter as its ``groups`` argument: a metadata column name, a list of
307+
column names combined into a compound key, or a callable
308+
``metadata -> array``. When ``None`` (the default), the splitter's own
309+
default grouping applies (e.g. ``"subject"`` / ``"session"`` / labels).
310+
Defaults to ``None``.
304311
save_model : bool
305312
Save model after training, for each fold of cross-validation if needed.
306313
Defaults to ``False``.
@@ -359,6 +366,7 @@ def __init__(
359366
n_splits: Optional[int] = None,
360367
cv_class: Optional[type] = None,
361368
cv_kwargs: Optional[dict] = None,
369+
groups=None,
362370
save_model: bool = False,
363371
cache_config: Optional["CacheConfig"] = None,
364372
optuna: bool = False,
@@ -376,6 +384,7 @@ def __init__(
376384
self.n_splits = n_splits
377385
self.cv_class = cv_class
378386
self.cv_kwargs = {} if cv_kwargs is None else cv_kwargs
387+
self.groups = groups
379388
self.save_model = save_model
380389
self.cache_config = cache_config
381390
self.optuna = optuna

moabb/evaluations/evaluations.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class WithinSessionEvaluation(BaseEvaluation):
8383
def _create_splitter(self):
8484
"""Create the WithinSessionSplitter for parallel evaluation."""
8585
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
86+
if self.groups is not None:
87+
cv_kwargs = {**cv_kwargs, "groups": self.groups}
8688
return WithinSessionSplitter(
8789
n_folds=self.n_splits or 5,
8890
shuffle=True,
@@ -296,6 +298,8 @@ class CrossSessionEvaluation(BaseEvaluation):
296298
def _create_splitter(self):
297299
"""Create the CrossSessionSplitter for parallel evaluation."""
298300
cv_class, cv_kwargs = self._resolve_cv(LeaveOneGroupOut)
301+
if self.groups is not None:
302+
cv_kwargs = {**cv_kwargs, "groups": self.groups}
299303
return CrossSessionSplitter(
300304
cv_class=cv_class, random_state=self.random_state, **cv_kwargs
301305
)
@@ -476,6 +480,8 @@ def _create_splitter(self):
476480
default_kwargs = {"n_splits": self.n_splits}
477481

478482
cv_class, cv_kwargs = self._resolve_cv(default_class, default_kwargs)
483+
if self.groups is not None:
484+
cv_kwargs = {**cv_kwargs, "groups": self.groups}
479485
return CrossSubjectSplitter(
480486
cv_class=cv_class, random_state=self.random_state, **cv_kwargs
481487
)
@@ -657,6 +663,8 @@ class WithinSubjectEvaluation(BaseEvaluation):
657663
def _create_splitter(self):
658664
"""Create the WithinSubjectSplitter for parallel evaluation."""
659665
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
666+
if self.groups is not None:
667+
cv_kwargs = {**cv_kwargs, "groups": self.groups}
660668
return WithinSubjectSplitter(
661669
n_folds=self.n_splits or 5,
662670
shuffle=True,

0 commit comments

Comments
 (0)