Skip to content

Commit 457e479

Browse files
authored
Allow metadata-aware splitter as cv_class used directly
1 parent 9c2f3c9 commit 457e479

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

moabb/evaluations/splitters.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ def _splitter_metadata(splitter):
2828
return None
2929

3030

31+
def _is_metadata_aware_cv(cv_class):
32+
"""Return True when ``cv_class`` is a metadata-aware top-level splitter.
33+
34+
Such a splitter follows the moabb convention: it declares
35+
``metadata_columns`` and its ``split`` method accepts ``(self, y, metadata)``
36+
(rather than the sklearn ``(self, X, y, groups)`` signature). When this is
37+
the case, the splitter is used directly as the top-level splitter instead of
38+
being wrapped as the inner groups-CV.
39+
"""
40+
if not hasattr(cv_class, "metadata_columns"):
41+
return False
42+
try:
43+
params = inspect.signature(cv_class.split).parameters
44+
except (TypeError, ValueError):
45+
return False
46+
return "metadata" in params
47+
48+
3149
class WithinSessionSplitter(BaseCrossValidator):
3250
"""Data splitter for within session evaluation.
3351
@@ -388,6 +406,11 @@ def __init__(
388406

389407
# Detect whether the cv_class uses the groups parameter
390408
self._cv_uses_groups = issubclass(cv_class, GroupsConsumerMixin)
409+
# Detect whether the cv_class is a metadata-aware top-level splitter
410+
# (i.e. it follows the moabb convention and declares ``metadata_columns``
411+
# and implements ``split(self, y, metadata)``). In that case, it is used
412+
# directly instead of being wrapped as the inner groups-CV.
413+
self._cv_is_metadata_aware = _is_metadata_aware_cv(cv_class)
391414
self._last_split_metadata = None
392415

393416
def get_n_splits(self, metadata):
@@ -409,6 +432,8 @@ def get_n_splits(self, metadata):
409432
n_splits: int
410433
The number of splits for the cross-validation
411434
"""
435+
if self._cv_is_metadata_aware:
436+
return self.cv_class(**self._cv_kwargs).get_n_splits(metadata)
412437
subjects = metadata["subject"].unique()
413438
n_splits = 0
414439
for subject in subjects: # noqa: B007 — referenced via @subject in pandas query below
@@ -429,6 +454,16 @@ def split(self, y, metadata):
429454
# here, I am getting the index across all the subject
430455
all_index = metadata.index.values
431456
self._last_split_metadata = None
457+
458+
# When the cv_class is a metadata-aware top-level splitter, delegate the
459+
# fold creation to it directly, forwarding the full metadata.
460+
if self._cv_is_metadata_aware:
461+
splitter = self.cv_class(**self._cv_kwargs)
462+
for train_idx, test_idx in splitter.split(y, metadata):
463+
self._last_split_metadata = _splitter_metadata(splitter)
464+
yield train_idx, test_idx
465+
return
466+
432467
# I check how many subjects are here:
433468
subjects = metadata["subject"].unique()
434469

@@ -539,6 +574,11 @@ def __init__(
539574

540575
# Detect whether the cv_class uses the groups parameter
541576
self._cv_uses_groups = issubclass(cv_class, GroupsConsumerMixin)
577+
# Detect whether the cv_class is a metadata-aware top-level splitter
578+
# (i.e. it follows the moabb convention and declares ``metadata_columns``
579+
# and implements ``split(self, y, metadata)``). In that case, it is used
580+
# directly instead of being wrapped as the inner groups-CV.
581+
self._cv_is_metadata_aware = _is_metadata_aware_cv(cv_class)
542582
self._last_split_metadata = None
543583

544584
def get_n_splits(self, metadata):
@@ -562,6 +602,8 @@ def get_n_splits(self, metadata):
562602
"""
563603

564604
splitter = self.cv_class(**self._cv_kwargs)
605+
if self._cv_is_metadata_aware:
606+
return splitter.get_n_splits(metadata)
565607
get_n_splits_kwargs = {"X": metadata.index}
566608
if self._cv_uses_groups:
567609
get_n_splits_kwargs["groups"] = metadata["subject"]
@@ -575,6 +617,14 @@ def split(self, y, metadata):
575617
splitter = self.cv_class(**self._cv_kwargs)
576618
self._last_split_metadata = None
577619

620+
# When the cv_class is a metadata-aware top-level splitter, delegate the
621+
# fold creation to it directly, forwarding the full metadata.
622+
if self._cv_is_metadata_aware:
623+
for train_idx, test_idx in splitter.split(y, metadata):
624+
self._last_split_metadata = _splitter_metadata(splitter)
625+
yield train_idx, test_idx
626+
return
627+
578628
# Only pass groups to cv_classes that actually use them
579629
# (detected via GroupsConsumerMixin). This avoids the
580630
# "The groups parameter is ignored" warning from e.g. TimeSeriesSplit.

moabb/tests/test_splits.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
from sklearn.model_selection import (
4+
BaseCrossValidator,
45
GroupShuffleSplit,
56
KFold,
67
LeaveOneGroupOut,
@@ -673,3 +674,75 @@ def test_cross_dataset_requires_group_column(data):
673674
splitter = CrossDatasetSplitter(group_column="does_not_exist")
674675
with pytest.raises(ValueError):
675676
list(splitter.split(y, metadata))
677+
678+
679+
class _TargetSubjectSplitter(BaseCrossValidator):
    """Single-fold, metadata-aware splitter for ``cv_class`` delegation tests.

    The one fold it yields tests on the rows of ``target`` recorded during
    ``test_session`` and trains on every row belonging to a different subject.
    """

    # Columns of ``metadata`` read by ``split``.
    metadata_columns = ("subject", "session")

    def __init__(self, target=None, test_session=None):
        self.target = target
        self.test_session = test_session

    def _iter_test_masks(self, X=None, y=None, groups=None):
        # Mask-based iteration is not provided by this splitter; ``split``
        # below builds the fold directly from ``metadata``.
        raise NotImplementedError

    def get_n_splits(self, metadata):
        """Return the number of folds, which is always 1."""
        return 1

    def split(self, y, metadata):
        """Yield one ``(train, test)`` pair of ``metadata`` index values."""
        index_values = metadata.index.values
        subject_col = metadata["subject"]
        session_col = metadata["session"]

        on_target = subject_col == self.target
        in_session = session_col == self.test_session
        held_out = on_target & in_session
        kept_for_training = subject_col != self.target

        yield index_values[kept_for_training.values], index_values[held_out.values]
705+
706+
707+
def test_metadata_aware_cv_class_used_directly(data):
    """CrossSubjectSplitter hands fold creation to a metadata-aware cv_class."""
    _, y, metadata = data
    subject_ids = metadata["subject"].unique()
    session_ids = metadata["session"].unique()
    target = subject_ids[1]
    test_session = session_ids[0]

    splitter = CrossSubjectSplitter(
        cv_class=_TargetSubjectSplitter,
        target=target,
        test_session=test_session,
    )

    # The delegate reports a single fold and yields exactly one.
    assert splitter.get_n_splits(metadata) == 1
    splits = list(splitter.split(y, metadata))
    assert len(splits) == 1
    train, test = splits[0]

    test_meta = metadata.loc[test]
    train_meta = metadata.loc[train]

    # Held-out rows: only the target subject, only the requested session.
    assert set(test_meta["subject"]) == {target}
    assert set(test_meta["session"]) == {test_session}

    # Training rows: the target subject never appears, and no index is shared
    # between the two sides of the fold.
    assert target not in set(train_meta["subject"])
    assert len(set(train) & set(test)) == 0
731+
732+
733+
def test_metadata_aware_cv_class_cross_session(data):
    """CrossSessionSplitter hands fold creation to a metadata-aware cv_class."""
    _, y, metadata = data
    subject_ids = metadata["subject"].unique()
    session_ids = metadata["session"].unique()
    target = subject_ids[1]
    test_session = session_ids[0]

    splitter = CrossSessionSplitter(
        cv_class=_TargetSubjectSplitter,
        target=target,
        test_session=test_session,
    )

    # The delegate reports a single fold and yields exactly one.
    assert splitter.get_n_splits(metadata) == 1
    splits = list(splitter.split(y, metadata))
    assert len(splits) == 1
    train, test = splits[0]

    # The test side holds only the target subject and shares no index with
    # the training side.
    test_subjects = set(metadata.loc[test]["subject"])
    assert test_subjects == {target}
    assert len(set(train) & set(test)) == 0

0 commit comments

Comments
 (0)