|
10 | 10 | Channels are concatenated and rescaled to create features vectors that will be |
11 | 11 | fed into a logistic regression. |
12 | 12 | """ |
| 13 | + |
13 | 14 | # Authors: Alexandre Barachant <alexandre.barachant@gmail.com> |
14 | 15 | # |
15 | 16 | # License: BSD-3-Clause |
|
26 | 27 | from sklearn.pipeline import make_pipeline |
27 | 28 | from sklearn.preprocessing import MinMaxScaler |
28 | 29 |
|
29 | | -from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events |
| 30 | +from mne import Epochs, io, pick_types, read_events |
30 | 31 | from mne.datasets import sample |
31 | | -from mne.decoding import Vectorizer |
32 | | -from mne.preprocessing import Xdawn |
| 32 | +from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator |
33 | 33 |
|
34 | 34 | print(__doc__) |
35 | 35 |
|
|
71 | 71 |
|
72 | 72 | # Create classification pipeline |
73 | 73 | clf = make_pipeline( |
74 | | - Xdawn(n_components=n_filter), |
| 74 | + XdawnTransformer(n_components=n_filter), |
75 | 75 | Vectorizer(), |
76 | 76 | MinMaxScaler(), |
77 | 77 | OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")), |
78 | 78 | ) |
79 | 79 |
|
80 | | -# Get the labels |
81 | | -labels = epochs.events[:, -1] |
| 80 | +# Get the data and labels |
| 81 | +# X is of shape (n_epochs, n_channels, n_times) |
| 82 | +X = epochs.get_data(copy=False) |
| 83 | +y = epochs.events[:, -1] |
82 | 84 |
|
83 | 85 | # Cross validator |
84 | 86 | cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) |
85 | 87 |
|
86 | 88 | # Do cross-validation |
87 | | -preds = np.empty(len(labels)) |
88 | | -for train, test in cv.split(epochs, labels): |
89 | | - clf.fit(epochs[train], labels[train]) |
90 | | - preds[test] = clf.predict(epochs[test]) |
| 89 | +preds = np.empty(len(y)) |
| 90 | +for train, test in cv.split(epochs, y): |
| 91 | + clf.fit(X[train], y[train]) |
| 92 | + preds[test] = clf.predict(X[test]) |
91 | 93 |
|
92 | 94 | # Classification report |
93 | 95 | target_names = ["aud_l", "aud_r", "vis_l", "vis_r"] |
94 | | -report = classification_report(labels, preds, target_names=target_names) |
| 96 | +report = classification_report(y, preds, target_names=target_names) |
95 | 97 | print(report) |
96 | 98 |
|
97 | 99 | # Normalized confusion matrix |
98 | | -cm = confusion_matrix(labels, preds) |
| 100 | +cm = confusion_matrix(y, preds) |
99 | 101 | cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis] |
100 | 102 |
|
101 | 103 | # Plot confusion matrix |
|
109 | 111 | ax.set(ylabel="True label", xlabel="Predicted label") |
110 | 112 |
|
111 | 113 | # %% |
112 | | -# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last |
113 | | -# cross-validation fold) can be used for visualization. |
114 | | - |
115 | | -fig, axes = plt.subplots( |
116 | | - nrows=len(event_id), |
117 | | - ncols=n_filter, |
118 | | - figsize=(n_filter, len(event_id) * 2), |
119 | | - layout="constrained", |
| 114 | +# Patterns of a fitted XdawnTransformer instance (here from the last |
| 115 | +# cross-validation fold) can be visualized using the SpatialFilter container.
| 116 | + |
| 117 | +# Instantiate SpatialFilter |
| 118 | +spf = get_spatial_filter_from_estimator( |
| 119 | + clf, info=epochs.info, step_name="xdawntransformer" |
| 120 | +) |
| 121 | + |
| 122 | +# Let's first examine the scree plot of generalized eigenvalues |
| 123 | +# for each class. |
| 124 | +spf.plot_scree(title="") |
| 125 | + |
| 126 | +# We can see that for all four classes the ~five largest components
| 127 | +# capture most of the variance; let's plot their patterns.
| 128 | +# Each class will now get its own figure.
| 129 | +components_to_plot = np.arange(5) |
| 130 | +figs = spf.plot_patterns( |
| 131 | + # Indices of patterns to plot, |
| 132 | +    # we will plot the first five for each class
| 133 | + components=components_to_plot, |
| 134 | + show=False, # to set the titles below |
120 | 135 | ) |
121 | | -fitted_xdawn = clf.steps[0][1] |
122 | | -info = create_info(epochs.ch_names, 1, epochs.get_channel_types()) |
123 | | -info.set_montage(epochs.get_montage()) |
124 | | -for ii, cur_class in enumerate(sorted(event_id)): |
125 | | - cur_patterns = fitted_xdawn.patterns_[cur_class] |
126 | | - pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0) |
127 | | - pattern_evoked.plot_topomap( |
128 | | - times=np.arange(n_filter), |
129 | | - time_format="Component %d" if ii == 0 else "", |
130 | | - colorbar=False, |
131 | | - show_names=False, |
132 | | - axes=axes[ii], |
133 | | - show=False, |
134 | | - ) |
135 | | - axes[ii, 0].set(ylabel=cur_class) |
| 136 | + |
| 137 | +# Set the class titles |
| 138 | +event_id_reversed = {v: k for k, v in event_id.items()} |
| 139 | +for fig, class_idx in zip(figs, clf[0].classes_): |
| 140 | + class_name = event_id_reversed[class_idx] |
| 141 | + fig.suptitle(class_name, fontsize=16) |
| 142 | + |
136 | 143 |
|
137 | 144 | # %% |
138 | 145 | # References |
|
0 commit comments