Skip to content

Commit 5b80814

Browse files
authored
Use metadata-aware cv_class directly; add tests and whatsnew
1 parent 457e479 commit 5b80814

4 files changed

Lines changed: 73 additions & 0 deletions

File tree

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Enhancements
3838
- Skip zip extraction in :class:`moabb.datasets.GuttmannFlury2025` when files are already extracted, with ``/scratch`` fallback for NFS filesystems on compute nodes (by `Bruno Aristimunha`_)
3939
- Re-enable auto-execution of the Riemannian Artifact Rejection tutorial (``examples/advanced_examples/plot_riemannian_artifact_rejection.py``) now that pyRiemann 0.11 is on PyPI with per-potato metrics and ``method_combination`` support on ``PotatoField`` (by `Bruno Aristimunha`_)
4040
- Use NEMAR as the default download source for datasets with an assigned ``nemar_id``, while preserving existing dataset-specific downloaders as a fallback (by `Bruno Aristimunha`_).
41+
- Allow a metadata-aware splitter to be passed as ``cv_class`` to :class:`moabb.evaluations.CrossSubjectEvaluation`/:class:`moabb.evaluations.CrossSessionEvaluation` (and the ``CrossSubjectSplitter``/``CrossSessionSplitter`` wrappers). A ``cv_class`` follows the moabb convention when it declares ``metadata_columns`` and implements ``split(self, y, metadata)``. Such a class is now used directly as the top-level splitter and receives the full ``metadata``, instead of being wrapped as the inner groups-CV. This enables metadata-driven folds (e.g. single-target cross-subject folds restricted to one session) through the public ``cv_class``/``cv_kwargs`` API (:gh:`1104` by `Bruno Aristimunha`_)
4142

4243
API changes
4344
~~~~~~~~~~~

moabb/evaluations/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,26 @@ def __init__(
388388
self.additional_columns = []
389389

390390
if self.cv_class is not None and hasattr(self.cv_class, "metadata_columns"):
391+
# ``metadata_columns`` may describe either extra *output* columns an
392+
# inner CV produces (e.g. LearningCurveSplitter -> data_size) or the
393+
# *input* metadata columns a metadata-aware top-level splitter
394+
# consumes (e.g. subject/session). The latter are already core result
395+
# columns and must not be stored as extra numeric result columns.
396+
reserved_columns = {
397+
"time",
398+
"dataset",
399+
"subject",
400+
"session",
401+
"n_samples",
402+
"n_channels",
403+
"pipeline",
404+
"n_samples_test",
405+
"n_classes",
406+
"score",
407+
}
391408
for col in self.cv_class.metadata_columns:
409+
if col in reserved_columns:
410+
continue
392411
if col not in self.additional_columns:
393412
self.additional_columns.append(col)
394413

moabb/tests/test_evaluations.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pandas as pd
99
import pytest
1010
import sklearn.base
11+
import sklearn.model_selection
1112
from pyriemann.estimation import Covariances
1213
from pyriemann.spatialfilters import CSP
1314
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
@@ -483,6 +484,58 @@ def test_incompatibility_error_message(self):
483484
assert "requires at least 2 sessions" in error_msg
484485

485486

487+
def _flatten(x):
    """Collapse every axis after the first into one feature axis.

    Returns ``x`` reshaped to ``(len(x), -1)``, e.g. a 3D epochs array
    becomes a 2D ``(n_samples, n_features)`` array.
    """
    n_samples = len(x)
    flat_shape = (n_samples, -1)
    return x.reshape(flat_shape)
490+
491+
492+
class _TargetSubjectSplitter(sklearn.model_selection.BaseCrossValidator):
    """Metadata-aware top-level splitter producing a single fold.

    The test set holds the rows of subject ``target`` recorded in session
    ``test_session``; the training set holds every row of any other subject.
    """

    # Columns of ``metadata`` that ``split`` reads.
    metadata_columns = ("subject", "session")

    def __init__(self, target=None, test_session=None):
        self.test_session = test_session
        self.target = target

    def _iter_test_masks(self, X=None, y=None, groups=None):
        # Unused here: ``split`` is overridden and builds the fold itself.
        raise NotImplementedError

    def get_n_splits(self, metadata):
        """Return the number of folds, which is always one."""
        return 1

    def split(self, y, metadata):
        """Yield one ``(train, test)`` pair of index labels from ``metadata``."""
        labels = metadata.index.values
        subjects = metadata["subject"]
        is_target = subjects == self.target
        in_test_session = metadata["session"] == self.test_session
        held_out = is_target & in_test_session
        other_subjects = subjects != self.target
        yield labels[other_subjects.values], labels[held_out.values]
514+
515+
516+
def test_cross_subject_with_metadata_aware_cv_class():
    """Check that a metadata-aware ``cv_class`` is used directly for the folds."""
    dataset = FakeDataset(
        ["left_hand", "right_hand"], n_subjects=3, n_sessions=2, seed=3
    )
    evaluation = ev.CrossSubjectEvaluation(
        paradigm=FakeImageryParadigm(),
        datasets=[dataset],
        overwrite=True,
        cv_class=_TargetSubjectSplitter,
        cv_kwargs={"target": 2, "test_session": "0"},
    )
    # subject/session are core result columns, so they must stay out of
    # additional_columns.
    for core_column in ("subject", "session"):
        assert core_column not in evaluation.additional_columns

    pipelines = {"flat_lda": make_pipeline(FunctionTransformer(_flatten), LDA())}
    results = evaluation.process(pipelines)

    # A single fold is expected, tested on subject 2 at session "0".
    assert len(results) == 1
    assert set(map(int, results["subject"])) == {2}
    assert set(results["session"]) == {"0"}
537+
538+
486539
class TestUtilEvaluation:
487540
def test_save_model_cv(self):
488541
model = Dummy()

test_save_path/fitted_model_0.pkl

130 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)