Skip to content

Commit f4853ce

Browse files
FIX: Fix bug with fitting coil order and GOF (mne-tools#13525)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 718aabf commit f4853ce

6 files changed

Lines changed: 83 additions & 28 deletions

File tree

doc/changes/dev/13525.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug where :func:`mne.chpi.refit_hpi` did not take ``gof_limit`` into account when fitting HPI order, by `Eric Larson`_

doc/sphinxext/directive_formatting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def check_directive_formatting(*args):
6060
# another directive/another directive's content)
6161
if idx == 0:
6262
continue
63-
dir_pattern = r"\.\. [a-zA-Z]+::"
63+
dir_pattern = r"^\s*\.\. \w+::" # line might start with whitespace
6464
head_pattern = r"^[-|=|\^]+$"
6565
directive = re.search(dir_pattern, line)
6666
if directive is not None:
@@ -84,5 +84,5 @@ def check_directive_formatting(*args):
8484
if bad:
8585
sphinx_logger.warning(
8686
f"{source_type} '{name}' is missing a blank line before the "
87-
f"directive '{directive.group()}'"
87+
f"directive '{directive.group()}' on line {idx + 1}"
8888
)

examples/decoding/decoding_time_generalization_conditions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
=========================================================================
77
88
This example runs the analysis described in :footcite:`KingDehaene2014`. It
9-
illustrates how one can
10-
fit a linear classifier to identify a discriminatory topography at a given time
11-
instant and subsequently assess whether this linear model can accurately
12-
predict all of the time samples of a second set of conditions.
9+
illustrates how one can fit a linear classifier to identify a discriminatory
10+
topography at a given time instant and subsequently assess whether this linear
11+
model can accurately predict all of the time samples of a second set of conditions.
1312
"""
1413
# Authors: Jean-Rémi King <jeanremi.king@gmail.com>
1514
# Alexandre Gramfort <alexandre.gramfort@inria.fr>

mne/chpi.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -579,27 +579,37 @@ def _chpi_objective(x, coil_dev_rrs, coil_head_rrs):
579579
return d.sum()
580580

581581

582-
def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs):
582+
def _fit_chpi_quat(coil_dev_rrs, coil_head_rrs, *, quat=None):
583583
"""Fit rotation and translation (quaternion) parameters for cHPI coils."""
584584
denom = np.linalg.norm(coil_head_rrs - np.mean(coil_head_rrs, axis=0))
585585
denom *= denom
586586
# We could try to solve it the analytic way:
587587
# TODO someday we could choose to weight these points by their goodness
588588
# of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330
589-
quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0]
589+
if quat is None:
590+
quat = _fit_matched_points(coil_dev_rrs, coil_head_rrs)[0]
590591
gof = 1.0 - _chpi_objective(quat, coil_dev_rrs, coil_head_rrs) / denom
591592
return quat, gof
592593

593594

594-
def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix=""):
595+
def _fit_coil_order_dev_head_trans(
596+
dev_pnts, head_pnts, *, bias=True, gofs=None, gof_limit=0.98, prefix=""
597+
):
595598
"""Compute Device to Head transform allowing for permutiatons of points."""
599+
n_coils = len(dev_pnts)
596600
id_quat = np.zeros(6)
597-
best_order = None
601+
best_order = np.full(n_coils, -1, dtype=int)
598602
best_g = -999
599603
best_quat = id_quat
600-
for this_order in itertools.permutations(np.arange(len(head_pnts))):
604+
assert dev_pnts.shape == head_pnts.shape == (n_coils, 3)
605+
gofs = np.ones(n_coils) if gofs is None else gofs
606+
use_mask = _gof_use_mask(gofs, gof_limit=gof_limit)
607+
n_use = int(use_mask.sum()) # explicit int cast for itertools.permutations
608+
dev_pnts_tmp = dev_pnts[use_mask]
609+
# First pass: figure out best order using the good dev points
610+
for this_order in itertools.permutations(np.arange(len(head_pnts)), n_use):
601611
head_pnts_tmp = head_pnts[np.array(this_order)]
602-
this_quat, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp)
612+
this_quat, g = _fit_chpi_quat(dev_pnts_tmp, head_pnts_tmp)
603613
assert np.linalg.det(quat_to_rot(this_quat[:3])) > 0.9999
604614
if bias:
605615
# For symmetrical arrangements, flips can produce roughly
@@ -612,17 +622,35 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix="")
612622
if check_g > best_g:
613623
out_g = g
614624
best_g = check_g
615-
best_order = np.array(this_order)
625+
best_order[use_mask] = this_order
616626
best_quat = this_quat
627+
del this_order
628+
# Second pass: now fit the remaining (bad) coils using the best order and quat
629+
# from above
630+
missing = np.setdiff1d(np.arange(n_coils), best_order[best_order >= 0])
631+
best_missing_g = -np.inf
632+
for this_order in itertools.permutations(missing):
633+
full_order = best_order.copy()
634+
full_order[~use_mask] = this_order
635+
assert (full_order >= 0).all()
636+
assert np.array_equal(np.sort(full_order), np.arange(n_coils))
637+
head_pnts_tmp = head_pnts[np.array(full_order)]
638+
_, g = _fit_chpi_quat(dev_pnts, head_pnts_tmp, quat=best_quat)
639+
if g > best_missing_g:
640+
best_missing_g = g
641+
best_order[:] = full_order
642+
del this_order
643+
assert np.array_equal(np.sort(best_order), np.arange(n_coils))
617644

618645
# Convert Quaterion to transform
619646
dev_head_t = _quat_to_affine(best_quat)
620647
ang, dist = angle_distance_between_rigid(
621648
dev_head_t, angle_units="deg", distance_units="mm"
622649
)
650+
extra = f" using {n_use}/{n_coils} coils" if n_use < n_coils else ""
623651
logger.info(
624652
f"{prefix}Fitted dev_head_t {ang:0.1f}° and {dist:0.1f} mm "
625-
f"from device origin (GOF: {out_g:.3f})"
653+
f"from device origin{extra} (GOF: {out_g:.3f})"
626654
)
627655
return dev_head_t, best_order, out_g
628656

@@ -1703,7 +1731,8 @@ def refit_hpi(
17031731
:func:`~mne.chpi.compute_chpi_locs`.
17041732
3. Optionally determine coil digitization order by testing all permutations
17051733
for the best goodness of fit between digitized coil locations and
1706-
(rigid-transformed) fitted coil locations.
1734+
(rigid-transformed) fitted coil locations, choosing the order first based on
1735+
those that satisfy ``gof_limit`` then the others.
17071736
4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``,
17081737
``dist_limit``, and ``use``.
17091738
5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries
@@ -1816,6 +1845,8 @@ def refit_hpi(
18161845
fit_dev_head_t, fit_order, _g = _fit_coil_order_dev_head_trans(
18171846
hpi_dev,
18181847
hpi_head,
1848+
gofs=hpi_gofs,
1849+
gof_limit=gof_limit,
18191850
prefix=" ",
18201851
)
18211852
else:
@@ -1824,27 +1855,21 @@ def refit_hpi(
18241855

18251856
# 4. Subselect usable coils and determine final dev_head_t
18261857
if isinstance(use, int) or use is None:
1827-
used = np.where(hpi_gofs >= gof_limit)[0]
1828-
if len(used) < 3:
1829-
gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs)
1830-
raise RuntimeError(
1831-
f"Only {len(used)} coil{_pl(used)} with goodness of fit >= {gof_limit}"
1832-
f", need at least 3 to refit HPI order (got {gofs})."
1833-
)
1834-
quat, _g = _fit_chpi_quat(hpi_dev[used], hpi_head[fit_order][used])
1858+
use_mask = _gof_use_mask(hpi_gofs, gof_limit=gof_limit)
1859+
quat, _g = _fit_chpi_quat(hpi_dev[use_mask], hpi_head[fit_order][use_mask])
18351860
fit_dev_head_t = _quat_to_affine(quat)
18361861
hpi_head_got = apply_trans(fit_dev_head_t, hpi_dev)
18371862
dists = np.linalg.norm(hpi_head_got - hpi_head[fit_order], axis=1)
18381863
dist_str = " ".join(f"{dist * 1e3:.1f}" for dist in dists)
18391864
logger.info(f" Coil distances after initial fit: {dist_str} mm")
1840-
good_dists_idx = np.where(dists[used] <= dist_limit)[0]
1865+
good_dists_idx = np.where(dists[use_mask] <= dist_limit)[0]
18411866
if not len(good_dists_idx) >= 3:
18421867
raise RuntimeError(
1843-
f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} have distance "
1868+
f"Only {len(good_dists_idx)} coil{_pl(good_dists_idx)} with distance "
18441869
f"<= {dist_limit * 1e3:.1f} mm, need at least 3 to refit HPI order "
18451870
f"(got distances: {np.round(1e3 * dists, 1)})."
18461871
)
1847-
used = used[good_dists_idx]
1872+
used = np.where(use_mask)[0][good_dists_idx]
18481873
if use is not None:
18491874
used = np.sort(used[np.argsort(hpi_gofs[used])[-use:]])
18501875
else:
@@ -1927,6 +1952,19 @@ def refit_hpi(
19271952
return info
19281953

19291954

1955+
def _gof_use_mask(hpi_gofs, *, gof_limit):
1956+
assert isinstance(hpi_gofs, np.ndarray) and hpi_gofs.ndim == 1
1957+
use_mask = hpi_gofs >= gof_limit
1958+
n_use = use_mask.sum()
1959+
if n_use < 3:
1960+
gofs = ", ".join(f"{g:.3f}" for g in hpi_gofs)
1961+
raise RuntimeError(
1962+
f"Only {n_use} coil{_pl(n_use)} with goodness of fit >= {gof_limit}"
1963+
f", need at least 3 to refit HPI order (got {gofs})."
1964+
)
1965+
return use_mask
1966+
1967+
19301968
def _sorted_hpi_dig(dig, *, kinds=(FIFF.FIFFV_POINT_HPI,)):
19311969
return sorted(
19321970
# need .get here because the hpi_result["dig_points"] does not set it

mne/datasets/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
# update the checksum in the MNE_DATASETS dict below, and change version
8888
# here: ↓↓↓↓↓↓↓↓
8989
RELEASES = dict(
90-
testing="0.169",
90+
testing="0.170",
9191
misc="0.27",
9292
phantom_kit="0.2",
9393
ucl_opm_auditory="0.2",
@@ -115,7 +115,7 @@
115115
# Testing and misc are at the top as they're updated most often
116116
MNE_DATASETS["testing"] = dict(
117117
archive_name=f"{TESTING_VERSIONED}.tar.gz",
118-
hash="md5:bb0524db8605e96fde6333893a969766",
118+
hash="md5:ebd873ea89507cf5a75043f56119d22b",
119119
url=(
120120
"https://codeload.github.com/mne-tools/mne-testing-data/"
121121
f"tar.gz/{RELEASES['testing']}"

mne/tests/test_chpi.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
ctf_chpi_fname = data_path / "CTF" / "testdata_ctf_mc.ds"
7474
ctf_chpi_pos_fname = data_path / "CTF" / "testdata_ctf_mc.pos"
7575
chpi_problem_fname = data_path / "SSS" / "chpi_problematic-info.fif"
76+
chpi_bad_gof_fname = data_path / "SSS" / "chpi_bad_gof-info.fif"
7677

7778
art_fname = (
7879
data_path
@@ -1011,3 +1012,19 @@ def test_refit_hpi_locs_problematic():
10111012
)
10121013
assert 3 < ang < 6
10131014
assert 82 < dist < 87
1015+
1016+
1017+
@testing.requires_testing_data
1018+
def test_refit_hpi_locs_bad_gof():
1019+
"""Test that we can handle bad GOF HPI fits."""
1020+
# gh-13524
1021+
info = read_info(chpi_bad_gof_fname)
1022+
assert_array_equal(info["hpi_results"][-1]["used"], [2, 3, 4])
1023+
info_new = refit_hpi(info.copy(), amplitudes=False, locs=False)
1024+
assert_array_equal(info_new["hpi_results"][-1]["used"], [1, 2, 3, 4])
1025+
assert_trans_allclose(
1026+
info["dev_head_t"],
1027+
info_new["dev_head_t"],
1028+
dist_tol=1e-3,
1029+
angle_tol=1,
1030+
)

0 commit comments

Comments
 (0)