|
8 | 8 | import pandas as pd |
9 | 9 | import pytest |
10 | 10 | import sklearn.base |
| 11 | +import sklearn.model_selection |
11 | 12 | from pyriemann.estimation import Covariances |
12 | 13 | from pyriemann.spatialfilters import CSP |
13 | 14 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA |
@@ -483,6 +484,58 @@ def test_incompatibility_error_message(self): |
483 | 484 | assert "requires at least 2 sessions" in error_msg |
484 | 485 |
|
485 | 486 |
|
| 487 | +def _flatten(x): |
| 488 | + """Flatten a 3D epochs array to 2D (n_samples, n_features).""" |
| 489 | + return x.reshape(len(x), -1) |
| 490 | + |
| 491 | + |
| 492 | +class _TargetSubjectSplitter(sklearn.model_selection.BaseCrossValidator): |
| 493 | + """Metadata-aware top-level splitter (test = target@session, train = others).""" |
| 494 | + |
| 495 | + metadata_columns = ("subject", "session") |
| 496 | + |
| 497 | + def __init__(self, target=None, test_session=None): |
| 498 | + self.target = target |
| 499 | + self.test_session = test_session |
| 500 | + |
| 501 | + def _iter_test_masks(self, X=None, y=None, groups=None): |
| 502 | + raise NotImplementedError |
| 503 | + |
| 504 | + def get_n_splits(self, metadata): |
| 505 | + return 1 |
| 506 | + |
| 507 | + def split(self, y, metadata): |
| 508 | + idx = metadata.index.values |
| 509 | + test_mask = (metadata["subject"] == self.target) & ( |
| 510 | + metadata["session"] == self.test_session |
| 511 | + ) |
| 512 | + train_mask = metadata["subject"] != self.target |
| 513 | + yield idx[train_mask.values], idx[test_mask.values] |
| 514 | + |
| 515 | + |
| 516 | +def test_cross_subject_with_metadata_aware_cv_class(): |
| 517 | + """A metadata-aware splitter passed via cv_class is used directly.""" |
| 518 | + ds = FakeDataset(["left_hand", "right_hand"], n_subjects=3, n_sessions=2, seed=3) |
| 519 | + evaluation = ev.CrossSubjectEvaluation( |
| 520 | + paradigm=FakeImageryParadigm(), |
| 521 | + datasets=[ds], |
| 522 | + overwrite=True, |
| 523 | + cv_class=_TargetSubjectSplitter, |
| 524 | + cv_kwargs={"target": 2, "test_session": "0"}, |
| 525 | + ) |
| 526 | + # subject/session are core columns and must not leak into additional_columns. |
| 527 | + assert "subject" not in evaluation.additional_columns |
| 528 | + assert "session" not in evaluation.additional_columns |
| 529 | + |
| 530 | + pipe = {"flat_lda": make_pipeline(FunctionTransformer(_flatten), LDA())} |
| 531 | + results = evaluation.process(pipe) |
| 532 | + |
| 533 | + # Exactly one fold: test = target subject at the requested session. |
| 534 | + assert len(results) == 1 |
| 535 | + assert {int(s) for s in results["subject"]} == {2} |
| 536 | + assert set(results["session"]) == {"0"} |
| 537 | + |
| 538 | + |
486 | 539 | class TestUtilEvaluation: |
487 | 540 | def test_save_model_cv(self): |
488 | 541 | model = Dummy() |
|
0 commit comments