Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/decoding/ssd_spatial_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
frequency band of interest and the noise covariance based on surrounding
frequencies.
"""

# Author: Denis A. Engemann <denis.engemann@gmail.com>
# Victoria Peterson <victoriapeterson09@gmail.com>
# License: BSD-3-Clause
Expand Down Expand Up @@ -82,8 +83,8 @@
ssd_sources, sfreq=raw.info["sfreq"], n_fft=4096
)

# Get spec_ratio information (already sorted).
# Note that this is not necessary if sort_by_spectral_ratio=True (default).
# Get spec_ratio information (already sorted)
# Note that this is not necessary if sort_by_spectral_ratio=True (default)
spec_ratio, sorter = ssd.get_spectral_ratio(ssd_sources)

# Plot spectral ratio (see Eq. 24 in Nikulin et al., 2011).
Expand Down
33 changes: 32 additions & 1 deletion mne/decoding/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger,
)
from ._covs_ged import _ssd_estimate
from ._mod_ged import _ssd_mod
from ._mod_ged import _get_spectral_ratio, _ssd_mod
from .base import _GEDTransformer


Expand Down Expand Up @@ -289,6 +289,37 @@ def fit_transform(self, X, y=None, **fit_params):
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)

def get_spectral_ratio(self, ssd_sources):
"""Get the spectal signal-to-noise ratio for each spatial filter.

Spectral ratio measure for best n_components selection
See :footcite:`NikulinEtAl2011`, Eq. (24).

Parameters
----------
ssd_sources : array
Data projected to SSD space.

Returns
-------
spec_ratio : array, shape (n_channels)
Array with the sprectal ratio value for each component.
sorter_spec : array, shape (n_channels)
Array of indices for sorting spec_ratio.

References
----------
.. footbibliography::
"""
spec_ratio, sorter_spec = _get_spectral_ratio(
ssd_sources=ssd_sources,
sfreq=self.sfreq_,
n_fft=self.n_fft_,
freqs_signal=self.freqs_signal_,
freqs_noise=self.freqs_noise_,
)
return spec_ratio, sorter_spec

def inverse_transform(self):
"""Not implemented yet."""
raise NotImplementedError("inverse_transform is not yet available.")
Expand Down
38 changes: 38 additions & 0 deletions mne/decoding/tests/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,44 @@ def test_picks_arg():
ssd.fit(X).transform(X)


def test_get_spectral_ratio():
"""Test that method is the same as function in _mod_ged.py."""
X, _, _ = simulate_data()
sf = 250
n_channels = X.shape[0]
info = create_info(ch_names=n_channels, sfreq=sf, ch_types="eeg")

# Init
filt_params_signal = dict(
l_freq=freqs_sig[0],
h_freq=freqs_sig[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)
filt_params_noise = dict(
l_freq=freqs_noise[0],
h_freq=freqs_noise[1],
l_trans_bandwidth=1,
h_trans_bandwidth=1,
)

ssd = SSD(
info,
filt_params_signal,
filt_params_noise,
n_components=None,
sort_by_spectral_ratio=False,
)
ssd.fit(X)
ssd_sources = ssd.transform(X)
spec_ratio_ssd, sorter_spec_ssd = ssd.get_spectral_ratio(ssd_sources)
spec_ratio_ged, sorter_spec_ged = _get_spectral_ratio(
ssd_sources, ssd.sfreq_, ssd.n_fft_, ssd.freqs_signal_, ssd.freqs_noise_
)
assert_array_equal(spec_ratio_ssd, spec_ratio_ged)
assert_array_equal(sorter_spec_ssd, sorter_spec_ged)


@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*")
@pytest.mark.filterwarnings("ignore:.*is longer than.*")
@parametrize_with_checks(
Expand Down
Loading