@@ -28,6 +28,24 @@ def _splitter_metadata(splitter):
2828 return None
2929
3030
31+ def _is_metadata_aware_cv (cv_class ):
32+ """Return True when ``cv_class`` is a metadata-aware top-level splitter.
33+
34+ Such a splitter follows the moabb convention: it declares
35+ ``metadata_columns`` and its ``split`` method accepts ``(self, y, metadata)``
36+ (rather than the sklearn ``(self, X, y, groups)`` signature). When this is
37+ the case, the splitter is used directly as the top-level splitter instead of
38+ being wrapped as the inner groups-CV.
39+ """
40+ if not hasattr (cv_class , "metadata_columns" ):
41+ return False
42+ try :
43+ params = inspect .signature (cv_class .split ).parameters
44+ except (TypeError , ValueError ):
45+ return False
46+ return "metadata" in params
47+
48+
3149class WithinSessionSplitter (BaseCrossValidator ):
3250 """Data splitter for within session evaluation.
3351
@@ -388,6 +406,11 @@ def __init__(
388406
389407 # Detect whether the cv_class uses the groups parameter
390408 self ._cv_uses_groups = issubclass (cv_class , GroupsConsumerMixin )
409+ # Detect whether the cv_class is a metadata-aware top-level splitter
410+ # (i.e. it follows the moabb convention and declares ``metadata_columns``
411+ # and implements ``split(self, y, metadata)``). In that case, it is used
412+ # directly instead of being wrapped as the inner groups-CV.
413+ self ._cv_is_metadata_aware = _is_metadata_aware_cv (cv_class )
391414 self ._last_split_metadata = None
392415
393416 def get_n_splits (self , metadata ):
@@ -409,6 +432,8 @@ def get_n_splits(self, metadata):
409432 n_splits: int
410433 The number of splits for the cross-validation
411434 """
435+ if self ._cv_is_metadata_aware :
436+ return self .cv_class (** self ._cv_kwargs ).get_n_splits (metadata )
412437 subjects = metadata ["subject" ].unique ()
413438 n_splits = 0
414439 for subject in subjects : # noqa: B007 — referenced via @subject in pandas query below
@@ -429,6 +454,16 @@ def split(self, y, metadata):
429454 # here, I am getting the index across all the subject
430455 all_index = metadata .index .values
431456 self ._last_split_metadata = None
457+
458+ # When the cv_class is a metadata-aware top-level splitter, delegate the
459+ # fold creation to it directly, forwarding the full metadata.
460+ if self ._cv_is_metadata_aware :
461+ splitter = self .cv_class (** self ._cv_kwargs )
462+ for train_idx , test_idx in splitter .split (y , metadata ):
463+ self ._last_split_metadata = _splitter_metadata (splitter )
464+ yield train_idx , test_idx
465+ return
466+
432467 # I check how many subjects are here:
433468 subjects = metadata ["subject" ].unique ()
434469
@@ -539,6 +574,11 @@ def __init__(
539574
540575 # Detect whether the cv_class uses the groups parameter
541576 self ._cv_uses_groups = issubclass (cv_class , GroupsConsumerMixin )
577+ # Detect whether the cv_class is a metadata-aware top-level splitter
578+ # (i.e. it follows the moabb convention and declares ``metadata_columns``
579+ # and implements ``split(self, y, metadata)``). In that case, it is used
580+ # directly instead of being wrapped as the inner groups-CV.
581+ self ._cv_is_metadata_aware = _is_metadata_aware_cv (cv_class )
542582 self ._last_split_metadata = None
543583
544584 def get_n_splits (self , metadata ):
@@ -562,6 +602,8 @@ def get_n_splits(self, metadata):
562602 """
563603
564604 splitter = self .cv_class (** self ._cv_kwargs )
605+ if self ._cv_is_metadata_aware :
606+ return splitter .get_n_splits (metadata )
565607 get_n_splits_kwargs = {"X" : metadata .index }
566608 if self ._cv_uses_groups :
567609 get_n_splits_kwargs ["groups" ] = metadata ["subject" ]
@@ -575,6 +617,14 @@ def split(self, y, metadata):
575617 splitter = self .cv_class (** self ._cv_kwargs )
576618 self ._last_split_metadata = None
577619
620+ # When the cv_class is a metadata-aware top-level splitter, delegate the
621+ # fold creation to it directly, forwarding the full metadata.
622+ if self ._cv_is_metadata_aware :
623+ for train_idx , test_idx in splitter .split (y , metadata ):
624+ self ._last_split_metadata = _splitter_metadata (splitter )
625+ yield train_idx , test_idx
626+ return
627+
578628 # Only pass groups to cv_classes that actually use them
579629 # (detected via GroupsConsumerMixin). This avoids the
580630 # "The groups parameter is ignored" warning from e.g. TimeSeriesSplit.
0 commit comments