|
19 | 19 | from mne._fiff.proj import make_eeg_average_ref_proj |
20 | 20 | from mne.cov import Covariance, _regularized_covariance |
21 | 21 | from mne.decoding._ged import ( |
| 22 | + _get_cov_def, |
22 | 23 | _get_restr_mat, |
23 | 24 | _handle_restr_mat, |
24 | | - _is_cov_pos_def, |
25 | | - _is_cov_symm_pos_semidef, |
| 25 | + _is_cov_symm, |
26 | 26 | _smart_ajd, |
27 | 27 | _smart_ged, |
28 | 28 | ) |
@@ -345,34 +345,27 @@ def test__handle_restr_mat_invalid_restr_type(): |
345 | 345 |
|
346 | 346 | def test_cov_validators(): |
347 | 347 | """Test that covariance validators indeed validate.""" |
348 | | - asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) |
| 348 | + asymm_indef = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) |
349 | 349 | sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) |
350 | 350 | pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) |
351 | 351 |
|
352 | | - assert not _is_cov_symm_pos_semidef(asymm) |
353 | | - assert _is_cov_symm_pos_semidef(sing_pos_semidef) |
354 | | - assert _is_cov_symm_pos_semidef(pos_def) |
| 352 | + assert not _is_cov_symm(asymm_indef) |
| 353 | + assert _get_cov_def(asymm_indef) == "indef" |
| 354 | + assert _get_cov_def(sing_pos_semidef) == "pos_semidef" |
| 355 | + assert _get_cov_def(pos_def) == "pos_def" |
355 | 356 |
|
356 | | - assert not _is_cov_pos_def(asymm) |
357 | | - assert not _is_cov_pos_def(sing_pos_semidef) |
358 | | - assert _is_cov_pos_def(pos_def) |
359 | 357 |
|
360 | | - |
361 | | -def test__is_cov_pos_def(): |
362 | | - """Test _is_cov_pos_def works.""" |
363 | | - asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) |
364 | | - sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) |
365 | | - pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) |
366 | | - assert not _is_cov_pos_def(asymm) |
367 | | - assert not _is_cov_pos_def(sing_pos_semidef) |
368 | | - assert _is_cov_pos_def(pos_def) |
369 | | - |
370 | | - |
371 | | -def test__smart_ajd_when_restr_mat_is_none(): |
372 | | - """Test _smart_ajd raises ValueError when restr_mat is None.""" |
| 358 | +def test__smart_ajd_raises(): |
| 359 | + """Test _smart_ajd raises proper ValueErrors.""" |
| 360 | + asymm_indef = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) |
373 | 361 | sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) |
374 | 362 | pos_def1 = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) |
375 | 363 | pos_def2 = np.array([[10, 1, 2], [1, 12, 3], [2, 3, 15]]) |
| 364 | + |
| 365 | + bad_covs = np.stack([sing_pos_semidef, asymm_indef, pos_def1]) |
| 366 | + with pytest.raises(ValueError, match="positive semi-definite"): |
| 367 | + _smart_ajd(bad_covs, restr_mat=pos_def2, weights=None) |
| 368 | + |
376 | 369 | bad_covs = np.stack([sing_pos_semidef, pos_def1, pos_def2]) |
377 | 370 | with pytest.raises(ValueError, match="positive definite"): |
378 | 371 | _smart_ajd(bad_covs, restr_mat=None, weights=None) |
|
0 commit comments