Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Enhancements
- Re-enable auto-execution of the Riemannian Artifact Rejection tutorial (``examples/advanced_examples/plot_riemannian_artifact_rejection.py``) now that pyRiemann 0.11 is on PyPI with per-potato metrics and ``method_combination`` support on ``PotatoField`` (by `Bruno Aristimunha`_)
- Use NEMAR as the default download source for datasets with an assigned ``nemar_id``, while preserving existing dataset-specific downloaders as a fallback (by `Bruno Aristimunha`_).
- Add :class:`moabb.datasets.preprocessing.EuclideanAlignment`, a trial-level Euclidean Alignment transformer (He & Wu 2020; Junqueira et al. 2024) that whitens each trial by the inverse square root of the Euclidean mean covariance to remove per-domain covariance shift before a (deep) model sees the data. Inductive and leakage-free by default (``fit`` learns the reference from training trials, ``transform`` re-applies it to unseen trials); ``fit_transform`` gives the transductive, per-recording form. Accepts an :class:`mne.BaseEpochs` or an ``(n_trials, n_channels, n_times)`` ndarray, uses a shrinkage covariance estimator (``"lwf"``) for robustness, and adds no new dependency (``pyriemann >= 0.11`` is already required). Distinct from :class:`pyriemann.transfer.TLCenter`, which recenters covariance *matrices* (:gh:`1108` by `Bruno Aristimunha`_).
- Drive cross-validation folds with any stock scikit-learn cross-validator passed as ``cv_class``, controlled by a ``groups`` argument — a metadata column name, a list of column names (compound key, e.g. ``["subject", "session"]``), or a callable ``metadata -> array`` — together with callable ``cv_kwargs`` resolved against the metadata (e.g. ``cv_class=PredefinedSplit`` with a ``test_fold`` callable to target a single fold). ``groups`` is exposed on :class:`moabb.evaluations.WithinSessionEvaluation`, :class:`moabb.evaluations.WithinSubjectEvaluation`, :class:`moabb.evaluations.CrossSessionEvaluation` and :class:`moabb.evaluations.CrossSubjectEvaluation` and threaded to their splitters; each splitter keeps its default grouping (``"subject"`` / ``"session"`` / labels) when ``groups`` is ``None``. :class:`moabb.evaluations.splitters.CrossDatasetSplitter` gains ``groups`` (its ``group_column`` argument is now a deprecated alias) (:gh:`1104` by `Bruno Aristimunha`_).

API changes
~~~~~~~~~~~
Expand Down
9 changes: 9 additions & 0 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,13 @@ class BaseEvaluation(ABC):
cv_kwargs : dict or None
Keyword arguments passed to cv_class when constructing the splitter.
Defaults to ``None``.
groups : str, list of str, callable, or None
What defines the cross-validation folds, forwarded to the evaluation's
splitter as its ``groups`` argument: a metadata column name, a list of
column names combined into a compound key, or a callable
``metadata -> array``. When ``None`` (the default), the splitter's own
default grouping applies (e.g. ``"subject"`` / ``"session"`` / labels).
Defaults to ``None``.
save_model : bool
Save model after training, for each fold of cross-validation if needed.
Defaults to ``False``.
Expand Down Expand Up @@ -359,6 +366,7 @@ def __init__(
n_splits: Optional[int] = None,
cv_class: Optional[type] = None,
cv_kwargs: Optional[dict] = None,
groups=None,
save_model: bool = False,
cache_config: Optional["CacheConfig"] = None,
optuna: bool = False,
Expand All @@ -376,6 +384,7 @@ def __init__(
self.n_splits = n_splits
self.cv_class = cv_class
self.cv_kwargs = {} if cv_kwargs is None else cv_kwargs
self.groups = groups
self.save_model = save_model
self.cache_config = cache_config
self.optuna = optuna
Expand Down
8 changes: 8 additions & 0 deletions moabb/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class WithinSessionEvaluation(BaseEvaluation):
def _create_splitter(self):
"""Create the WithinSessionSplitter for parallel evaluation."""
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
if self.groups is not None:
cv_kwargs = {**cv_kwargs, "groups": self.groups}
return WithinSessionSplitter(
n_folds=self.n_splits or 5,
shuffle=True,
Expand Down Expand Up @@ -296,6 +298,8 @@ class CrossSessionEvaluation(BaseEvaluation):
def _create_splitter(self):
"""Create the CrossSessionSplitter for parallel evaluation."""
cv_class, cv_kwargs = self._resolve_cv(LeaveOneGroupOut)
if self.groups is not None:
cv_kwargs = {**cv_kwargs, "groups": self.groups}
return CrossSessionSplitter(
cv_class=cv_class, random_state=self.random_state, **cv_kwargs
)
Expand Down Expand Up @@ -476,6 +480,8 @@ def _create_splitter(self):
default_kwargs = {"n_splits": self.n_splits}

cv_class, cv_kwargs = self._resolve_cv(default_class, default_kwargs)
if self.groups is not None:
cv_kwargs = {**cv_kwargs, "groups": self.groups}
return CrossSubjectSplitter(
cv_class=cv_class, random_state=self.random_state, **cv_kwargs
)
Expand Down Expand Up @@ -657,6 +663,8 @@ class WithinSubjectEvaluation(BaseEvaluation):
def _create_splitter(self):
"""Create the WithinSubjectSplitter for parallel evaluation."""
cv_class, cv_kwargs = self._resolve_cv(StratifiedKFold)
if self.groups is not None:
cv_kwargs = {**cv_kwargs, "groups": self.groups}
return WithinSubjectSplitter(
n_folds=self.n_splits or 5,
shuffle=True,
Expand Down
Loading
Loading