Commit 24800f9

feat(preprocessing): add EuclideanAlignment trial-level transformer (#1109)
1 parent 0738017 commit 24800f9

5 files changed

Lines changed: 458 additions & 0 deletions

docs/source/api.rst

Lines changed: 15 additions & 0 deletions
@@ -324,6 +324,21 @@ Utilities
     utils.plot_datasets_grid
     utils.plot_datasets_cluster
 
+-------------
+Preprocessing
+-------------
+.. currentmodule:: moabb.datasets
+
+Trial-level transformers applied to the epoched/array data, usable as
+pipeline steps (inductive in a cross-validation, transductive via
+``fit_transform`` on a single recording).
+
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    preprocessing.EuclideanAlignment
+
 Paradigms
 ---------
 .. currentmodule:: moabb.paradigms

docs/source/whats_new.rst

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ Enhancements
 - Skip zip extraction in :class:`moabb.datasets.GuttmannFlury2025` when files are already extracted, with ``/scratch`` fallback for NFS filesystems on compute nodes (by `Bruno Aristimunha`_)
 - Re-enable auto-execution of the Riemannian Artifact Rejection tutorial (``examples/advanced_examples/plot_riemannian_artifact_rejection.py``) now that pyRiemann 0.11 is on PyPI with per-potato metrics and ``method_combination`` support on ``PotatoField`` (by `Bruno Aristimunha`_)
 - Use NEMAR as the default download source for datasets with an assigned ``nemar_id``, while preserving existing dataset-specific downloaders as a fallback (by `Bruno Aristimunha`_).
+- Add :class:`moabb.datasets.preprocessing.EuclideanAlignment`, a trial-level Euclidean Alignment transformer (He & Wu 2020; Junqueira et al. 2024) that whitens each trial by the inverse square root of the Euclidean mean covariance to remove per-domain covariance shift before a (deep) model sees the data. Inductive and leakage-free by default (``fit`` learns the reference from training trials, ``transform`` re-applies it to unseen trials); ``fit_transform`` gives the transductive, per-recording form. Accepts an :class:`mne.BaseEpochs` or an ``(n_trials, n_channels, n_times)`` ndarray, uses a shrinkage covariance estimator (``"lwf"``) for robustness, and adds no new dependency (``pyriemann >= 0.11`` is already required). Distinct from :class:`pyriemann.transfer.TLCenter`, which recenters covariance *matrices* (:gh:`1108` by `Bruno Aristimunha`_).
 
 API changes
 ~~~~~~~~~~~
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
+r"""
+=====================================================
+Euclidean Alignment for cross-subject transfer
+=====================================================
+
+EEG covariance statistics drift from subject to subject (and session to
+session): the same mental task produces differently-shaped data on each
+recording. That domain shift is the main reason a decoder trained on one set of
+subjects transfers poorly to a new one. **Euclidean Alignment** (EA) removes it
+with a single, label-free whitening step, cheap enough to put in front of any
+model, deep or classical [1]_.
+
+In a systematic evaluation across MOABB motor-imagery datasets, Junqueira,
+Aristimunha, Chevallier & de Camargo (2024) [2]_ showed that aligning each
+recording with EA before training a *shared* deep model improved target-subject
+decoding by **+4.33%** on average and cut convergence time by **more than 70%**,
+for almost no compute and no extra labels. This example reproduces the core
+idea on the workhorse CSP+LDA motor-imagery pipeline using
+:class:`moabb.datasets.preprocessing.EuclideanAlignment`.
+
+Each trial :math:`X_i` is whitened by the inverse square root of the
+**Euclidean (arithmetic) mean** of the per-trial covariances of its recording,
+
+.. math::
+
+    \bar{C} = \frac{1}{N}\sum_{i=1}^{N} C_i,
+    \qquad \tilde{X}_i = \bar{C}^{-1/2} X_i,
+
+so after alignment every recording shares an identity-like average covariance
+and the subjects become comparable. We apply EA **per subject** (the
+transductive, per-recording form: :meth:`fit_transform` on one recording; it
+uses only the trial covariances, never the labels) and compare
+leave-one-subject-out decoding with and without it.
+"""
+
+# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+#
+# License: BSD (3-clause)
+
+import matplotlib.pyplot as plt
+import mne
+import numpy as np
+from mne.decoding import CSP
+from pyriemann.estimation import Covariances
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
+from sklearn.metrics import roc_auc_score
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import LabelEncoder
+
+import moabb
+from moabb.datasets import BNCI2014_001
+from moabb.datasets.preprocessing import EuclideanAlignment
+from moabb.paradigms import LeftRightImagery
+
+
+moabb.set_log_level("info")
+mne.set_log_level("WARNING")  # keep the gallery output readable
+
+###############################################################################
+# Load the data per subject
+# -------------------------
+#
+# We use the BCI Competition IV 2a dataset (:class:`moabb.datasets.BNCI2014_001`)
+# and the :class:`moabb.paradigms.LeftRightImagery` paradigm (left- vs right-hand
+# motor imagery, scored with ROC-AUC). We keep the trials of each subject
+# separate, because Euclidean Alignment is defined **per recording**.
+
+paradigm = LeftRightImagery()
+dataset = BNCI2014_001()
+subjects = dataset.subject_list[:8]
+
+# Pull each subject's trials once; X is (n_trials, n_channels, n_times).
+data = {}
+for subject in subjects:
+    X, labels, _ = paradigm.get_data(dataset, [subject])
+    data[subject] = (X, LabelEncoder().fit_transform(labels))
+
+###############################################################################
+# Euclidean Alignment reduces the between-subject covariance shift
+# ----------------------------------------------------------------
+#
+# Before any classification, we can *see* what EA does. For every subject we
+# compute the mean trial covariance, then measure how far apart the subjects are
+# as the average pairwise distance between those mean covariances. EA pulls them
+# together: each subject's mean covariance becomes ~identity.
+
+
+def mean_covariance(X):
+    """Euclidean mean of the per-trial covariances of one recording."""
+    return Covariances("oas").transform(X).mean(axis=0)
+
+
+def between_subject_dispersion(means):
+    """Average pairwise Frobenius distance between subject mean covariances."""
+    dists = [
+        np.linalg.norm(means[i] - means[j])
+        for i in range(len(means))
+        for j in range(i + 1, len(means))
+    ]
+    return float(np.mean(dists))
+
+
+raw_means, aligned_means = [], []
+for subject in subjects:
+    X, _ = data[subject]
+    raw_means.append(mean_covariance(X))
+    # Per-subject (transductive) Euclidean Alignment: label-free, leakage-free.
+    X_aligned = EuclideanAlignment().fit_transform(X)
+    aligned_means.append(mean_covariance(X_aligned))
+
+dispersion = {
+    "No alignment": between_subject_dispersion(raw_means),
+    "Euclidean Alignment": between_subject_dispersion(aligned_means),
+}
+print("Between-subject covariance dispersion:", dispersion)
+
+fig, ax = plt.subplots(figsize=(5, 4))
+ax.bar(dispersion.keys(), dispersion.values(), color=["#999999", "#0072B2"])
+ax.set_ylabel("Mean pairwise distance between\nsubject covariances (Frobenius)")
+ax.set_title("Euclidean Alignment shrinks the\nbetween-subject domain shift")
+fig.tight_layout()
+plt.show()
+
+###############################################################################
+# Leave-one-subject-out decoding, with and without alignment
+# ----------------------------------------------------------
+#
+# Now the payoff. For each held-out subject we train a standard CSP+LDA pipeline
+# on the *other* subjects and test on the held-out one, the cross-subject
+# transfer setting. We run it twice: on the raw trials, and on trials that have
+# each been Euclidean-aligned per subject.
+#
+# CSP+LDA is a *Euclidean* classifier and is therefore sensitive to the
+# covariance shift EA removes. (Riemannian tangent-space pipelines already
+# recenter covariances internally, so they benefit less; EA is most valuable
+# for Euclidean and deep models, exactly the setting of [2]_.)
+
+
+def decode_loso(aligned):
+    """Leave-one-subject-out ROC-AUC, optionally with per-subject EA."""
+    scores = []
+    for test_subject in subjects:
+        train_subjects = [s for s in subjects if s != test_subject]
+
+        def prep(subject):
+            X, y = data[subject]
+            if aligned:
+                X = EuclideanAlignment().fit_transform(X)
+            return X, y
+
+        X_train = np.concatenate([prep(s)[0] for s in train_subjects])
+        y_train = np.concatenate([prep(s)[1] for s in train_subjects])
+        X_test, y_test = prep(test_subject)
+
+        clf = make_pipeline(CSP(n_components=8), LDA())
+        clf.fit(X_train, y_train)
+        proba = clf.predict_proba(X_test)[:, 1]
+        scores.append(roc_auc_score(y_test, proba))
+    return np.array(scores)
+
+
+raw_scores = decode_loso(aligned=False)
+aligned_scores = decode_loso(aligned=True)
+
+for subject, raw, aligned in zip(subjects, raw_scores, aligned_scores):
+    print(f"subject {subject}: raw={raw:.3f} aligned={aligned:.3f}")
+print(
+    f"mean: raw={raw_scores.mean():.3f} aligned={aligned_scores.mean():.3f} "
+    f"(EA wins on {(aligned_scores > raw_scores).sum()}/{len(subjects)} subjects)"
+)
+
+###############################################################################
+# A point per held-out subject: above the diagonal means Euclidean Alignment
+# helped that subject's cross-subject transfer.
+
+fig, ax = plt.subplots(figsize=(5, 5))
+ax.scatter(raw_scores, aligned_scores, c="#0072B2", s=70, zorder=3)
+for subject, raw, aligned in zip(subjects, raw_scores, aligned_scores):
+    ax.annotate(f"S{subject}", (raw, aligned), textcoords="offset points", xytext=(6, 0))
+lims = [min(raw_scores.min(), aligned_scores.min()) - 0.02, 1.0]
+ax.plot(lims, lims, "--", color="grey", zorder=1)
+ax.set_xlim(lims)
+ax.set_ylim(lims)
+ax.set_xlabel("ROC-AUC without alignment")
+ax.set_ylabel("ROC-AUC with Euclidean Alignment")
+ax.set_title("Cross-subject transfer (leave-one-subject-out)")
+ax.set_aspect("equal")
+fig.tight_layout()
+plt.show()
+
+###############################################################################
+# Using it inside a MOABB evaluation
+# ----------------------------------
+#
+# Above we used the **transductive** per-recording form (``fit_transform`` on
+# each subject). :class:`~moabb.datasets.preprocessing.EuclideanAlignment` is
+# also a regular scikit-learn transformer, so its **inductive**, leakage-free
+# form drops straight into a pipeline for any MOABB evaluation: ``fit`` learns
+# the reference whitener from the training trials only and ``transform`` reuses
+# it on the test trials. For example::
+#
+#     from moabb.evaluations import CrossSubjectEvaluation
+#
+#     pipelines = {
+#         "EA+CSP+LDA": make_pipeline(
+#             EuclideanAlignment(), CSP(n_components=8), LDA()
+#         )
+#     }
+#     evaluation = CrossSubjectEvaluation(paradigm=paradigm, datasets=[dataset])
+#     results = evaluation.process(pipelines)
+#
+# For the full deep-learning story, where EA shines most (improving target
+# accuracy by +4.33% and cutting training time by >70%), see Junqueira et al.
+# (2024) [2]_.
+#
+# References
+# ----------
+# .. [1] He, H., & Wu, D. (2020). Transfer learning for brain-computer
+#        interfaces: A Euclidean space data alignment approach. *IEEE
+#        Transactions on Biomedical Engineering*, 67(2), 399-410.
+#        https://doi.org/10.1109/TBME.2019.2913914
+# .. [2] Junqueira, B., Aristimunha, B., Chevallier, S., & de Camargo, R. Y.
+#        (2024). A systematic evaluation of Euclidean alignment with deep
+#        learning for EEG decoding. *Journal of Neural Engineering*, 21(3),
+#        036038. https://doi.org/10.1088/1741-2552/ad4f18
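
The whitening formula in the example's docstring can be checked without MOABB or pyRiemann. The following is a minimal sketch, assuming synthetic data and the plain sample covariance `X X^T / n_times` rather than the shrinkage `"lwf"` estimator the transformer uses by default; `inv_sqrtm` and `euclidean_align` are hypothetical helper names introduced only for illustration, not part of the commit.

```python
import numpy as np


def inv_sqrtm(C):
    """Inverse square root of a symmetric positive-definite matrix."""
    eigvals, eigvecs = np.linalg.eigh(C)
    # V diag(1/sqrt(lambda)) V^T: column j of V is scaled by 1/sqrt(lambda_j).
    return (eigvecs / np.sqrt(eigvals)) @ eigvecs.T


def euclidean_align(X):
    """Whiten trials (n_trials, n_channels, n_times) by the mean covariance."""
    n_times = X.shape[-1]
    covs = X @ X.transpose(0, 2, 1) / n_times  # per-trial sample covariance
    ref = covs.mean(axis=0)  # arithmetic (Euclidean) mean
    return inv_sqrtm(ref) @ X  # the matmul broadcasts over trials


rng = np.random.default_rng(0)
mixing = rng.normal(size=(4, 4))  # synthetic per-"subject" channel mixing
X = mixing @ rng.normal(size=(50, 4, 200))  # (n_trials, n_channels, n_times)

X_aligned = euclidean_align(X)
aligned_covs = X_aligned @ X_aligned.transpose(0, 2, 1) / X.shape[-1]
mean_cov_after = aligned_covs.mean(axis=0)
print(np.round(mean_cov_after, 3))
```

With the sample covariance the mean covariance after alignment equals the identity up to floating-point error. With a shrinkage estimator it is only approximately the identity, which is why the docstring says "identity-like".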

moabb/datasets/preprocessing.py

Lines changed: 114 additions & 0 deletions
@@ -9,6 +9,7 @@
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.pipeline import FunctionTransformer, Pipeline, _name_estimators
 from sklearn.utils import Bunch
+from sklearn.utils.validation import check_is_fitted
 
 from moabb.datasets._channel_pick import pick_channels_for_modalities
 
@@ -1102,3 +1103,116 @@ def get_resample_pipeline(sfreq):
         func=methodcaller("resample", sfreq=sfreq, verbose=False),
         display_name=f"Resample ({sfreq} Hz)",
     )
+
+
+class EuclideanAlignment(TransformerMixin, BaseEstimator):
+    r"""Euclidean Alignment of trials (He & Wu, 2020).
+
+    Euclidean Alignment (EA) removes the per-domain (subject / session /
+    recording) covariance shift that makes a model trained on one set of
+    recordings transfer poorly to another. It is the simplest member of a
+    larger family of trial-alignment methods (others recenter on the
+    Riemannian or log-Euclidean mean) and is the one most used with deep
+    networks because it is cheap, label-free, and leaves the data as raw trials
+    a network can ingest [He2020]_ [Junqueira2024]_.
+
+    Each trial is whitened by the inverse square root of a single reference
+    covariance,
+
+    .. math::
+
+        \bar{C} = \frac{1}{N} \sum_{i=1}^{N} C_i, \qquad
+        \tilde{X}_i = \bar{C}^{-1/2} X_i ,
+
+    where :math:`C_i` is the spatial covariance of trial :math:`X_i` and
+    :math:`\bar{C}` is their **arithmetic (Euclidean) mean**. After alignment
+    the trials share an identity-like average covariance, so the domain shift
+    that lived in the second-order statistics is gone.
+
+    The transformer is **inductive** by default: :meth:`fit` learns
+    :math:`\bar{C}^{-1/2}` from the *training* trials and :meth:`transform`
+    re-applies that same whitener to unseen trials, so no test information leaks
+    into the alignment (the leakage that the transductive, fit-on-everything
+    form silently introduces). Calling :meth:`fit_transform` on a single
+    recording recovers the usual transductive, per-recording EA people
+    hand-roll: same object, no second class.
+
+    Unlike :class:`pyriemann.transfer.TLCenter` (with ``metric="euclid"``),
+    which recenters covariance *matrices* for a Riemannian classifier, this
+    operates directly on the ``(n_trials, n_channels, n_times)`` trials, so it
+    drops in front of any time-series model (CSP, EEGNet, ...).
+
+    Parameters
+    ----------
+    estimator : str, default "lwf"
+        Covariance estimator passed to
+        :func:`pyriemann.utils.covariance.covariances`. The shrinkage default
+        ``"lwf"`` (Ledoit-Wolf) keeps the per-trial covariances symmetric
+        positive-definite, and hence the reference mean invertible, even on
+        short or noisy trials, where the plain sample covariance (``"scm"`` /
+        ``"cov"``) can be ill-conditioned.
+
+    Attributes
+    ----------
+    inv_sqrt_ref_ : ndarray, shape (n_channels, n_channels)
+        Inverse square root :math:`\bar{C}^{-1/2}` of the reference mean
+        covariance learned in :meth:`fit`; the whitening matrix applied in
+        :meth:`transform`.
+
+    See Also
+    --------
+    pyriemann.transfer.TLCenter
+
+    Notes
+    -----
+    Accepts an :class:`mne.BaseEpochs` (read via ``get_data``) or an ndarray of
+    shape ``(n_trials, n_channels, n_times)``; :meth:`transform` returns an
+    ndarray of the same shape. ``pyriemann >= 0.11`` is already a hard moabb
+    dependency, so this adds no new requirement.
+
+    References
+    ----------
+    .. [He2020] He, H., & Wu, D. (2020). Transfer learning for brain-computer
+       interfaces: A Euclidean space data alignment approach. *IEEE
+       Transactions on Biomedical Engineering*, 67(2), 399-410.
+       https://doi.org/10.1109/TBME.2019.2913914
+    .. [Junqueira2024] Junqueira, B., Aristimunha, B., Chevallier, S., &
+       de Camargo, R. Y. (2024). A systematic evaluation of Euclidean alignment
+       with deep learning for EEG decoding. *Journal of Neural Engineering*,
+       21(3), 036038. https://doi.org/10.1088/1741-2552/ad4f18
+    """
+
+    def __init__(self, estimator="lwf"):
+        self.estimator = estimator
+
+    @staticmethod
+    def _array(X):
+        """Return trials as a float ``(n_trials, n_channels, n_times)`` ndarray."""
+        if hasattr(X, "get_data"):  # mne Epochs
+            X = X.get_data(copy=False)
+        X = np.asarray(X, dtype=float)
+        if X.ndim != 3:
+            raise ValueError(
+                "EuclideanAlignment expects trials shaped "
+                f"(n_trials, n_channels, n_times), got a {X.ndim}D input."
+            )
+        return X
+
+    def fit(self, X, y=None):
+        # Lazy import: pyriemann.utils.base emits a DeprecationWarning at import
+        # time and this core module is imported almost everywhere, so only
+        # EuclideanAlignment users pay it. These paths are valid for the declared
+        # pyriemann >= 0.11 floor and match the rest of moabb (pipelines.csp,
+        # pipelines.classification). The Euclidean mean is the arithmetic mean of
+        # the per-trial covariances, so no mean_covariance() call is needed.
+        from pyriemann.utils.base import invsqrtm
+        from pyriemann.utils.covariance import covariances
+
+        covs = covariances(self._array(X), estimator=self.estimator)
+        self.inv_sqrt_ref_ = invsqrtm(covs.mean(axis=0))
+        return self
+
+    def transform(self, X):
+        check_is_fitted(self, "inv_sqrt_ref_")
+        # (n_chans, n_chans) @ (n_trials, n_chans, n_times) -> (n_trials, n_chans, n_times)
+        return np.matmul(self.inv_sqrt_ref_, self._array(X))
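
The inductive `fit` / `transform` split described in the class docstring can be illustrated with a NumPy-only stand-in. This is a sketch under stated assumptions, not the committed implementation: `NumpyEA` and `mean_cov` are hypothetical names, the data are synthetic, and the plain sample covariance replaces pyRiemann's `"lwf"` estimator so the block runs without pyRiemann installed.

```python
import numpy as np


class NumpyEA:
    """Hypothetical NumPy-only stand-in mirroring the fit/transform split."""

    def fit(self, X, y=None):
        covs = X @ X.transpose(0, 2, 1) / X.shape[-1]  # plain sample covariance
        eigvals, eigvecs = np.linalg.eigh(covs.mean(axis=0))
        self.inv_sqrt_ref_ = (eigvecs / np.sqrt(eigvals)) @ eigvecs.T
        return self

    def transform(self, X):
        # (n_chans, n_chans) @ (n_trials, n_chans, n_times) broadcasts per trial
        return np.matmul(self.inv_sqrt_ref_, X)


def mean_cov(X):
    """Arithmetic mean of the per-trial sample covariances."""
    return (X @ X.transpose(0, 2, 1) / X.shape[-1]).mean(axis=0)


rng = np.random.default_rng(42)
scale = np.array([1.0, 2.0, 3.0])[:, None]  # unequal channel variances
X_train = rng.normal(size=(40, 3, 128)) * scale
X_test = rng.normal(size=(10, 3, 128)) * scale

ea = NumpyEA().fit(X_train)  # reference learned from training trials only
train_aligned = ea.transform(X_train)
test_aligned = ea.transform(X_test)  # same whitener reused, no refit on test data

# Broadcasting gives the same result as whitening each trial in a loop.
looped = np.stack([ea.inv_sqrt_ref_ @ trial for trial in X_test])

train_error = np.abs(mean_cov(train_aligned) - np.eye(3)).max()
test_error = np.abs(mean_cov(test_aligned) - np.eye(3)).max()
print(f"train deviation from identity: {train_error:.2e}")
print(f"test deviation from identity:  {test_error:.2e}")
```

The training deviation is numerically zero, while the test deviation is small but non-zero because the whitener was never fitted to those trials. Calling `fit` and `transform` on the same single recording gives the transductive form instead.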
