@@ -579,27 +579,37 @@ def _chpi_objective(x, coil_dev_rrs, coil_head_rrs):
579579 return d .sum ()
580580
581581
582- def _fit_chpi_quat (coil_dev_rrs , coil_head_rrs ):
582+ def _fit_chpi_quat (coil_dev_rrs , coil_head_rrs , * , quat = None ):
583583 """Fit rotation and translation (quaternion) parameters for cHPI coils."""
584584 denom = np .linalg .norm (coil_head_rrs - np .mean (coil_head_rrs , axis = 0 ))
585585 denom *= denom
586586 # We could try to solve it the analytic way:
587587 # TODO someday we could choose to weight these points by their goodness
588588 # of fit somehow, see also https://github.com/mne-tools/mne-python/issues/11330
589- quat = _fit_matched_points (coil_dev_rrs , coil_head_rrs )[0 ]
589+ if quat is None :
590+ quat = _fit_matched_points (coil_dev_rrs , coil_head_rrs )[0 ]
590591 gof = 1.0 - _chpi_objective (quat , coil_dev_rrs , coil_head_rrs ) / denom
591592 return quat , gof
592593
593594
594- def _fit_coil_order_dev_head_trans (dev_pnts , head_pnts , * , bias = True , prefix = "" ):
595+ def _fit_coil_order_dev_head_trans (
596+ dev_pnts , head_pnts , * , bias = True , gofs = None , gof_limit = 0.98 , prefix = ""
597+ ):
595598 """Compute Device to Head transform allowing for permutiatons of points."""
599+ n_coils = len (dev_pnts )
596600 id_quat = np .zeros (6 )
597- best_order = None
601+ best_order = np . full ( n_coils , - 1 , dtype = int )
598602 best_g = - 999
599603 best_quat = id_quat
600- for this_order in itertools .permutations (np .arange (len (head_pnts ))):
604+ assert dev_pnts .shape == head_pnts .shape == (n_coils , 3 )
605+ gofs = np .ones (n_coils ) if gofs is None else gofs
606+ use_mask = _gof_use_mask (gofs , gof_limit = gof_limit )
607+ n_use = int (use_mask .sum ()) # explicit int cast for itertools.permutations
608+ dev_pnts_tmp = dev_pnts [use_mask ]
609+ # First pass: figure out best order using the good dev points
610+ for this_order in itertools .permutations (np .arange (len (head_pnts )), n_use ):
601611 head_pnts_tmp = head_pnts [np .array (this_order )]
602- this_quat , g = _fit_chpi_quat (dev_pnts , head_pnts_tmp )
612+ this_quat , g = _fit_chpi_quat (dev_pnts_tmp , head_pnts_tmp )
603613 assert np .linalg .det (quat_to_rot (this_quat [:3 ])) > 0.9999
604614 if bias :
605615 # For symmetrical arrangements, flips can produce roughly
@@ -612,17 +622,35 @@ def _fit_coil_order_dev_head_trans(dev_pnts, head_pnts, *, bias=True, prefix="")
612622 if check_g > best_g :
613623 out_g = g
614624 best_g = check_g
615- best_order = np . array ( this_order )
625+ best_order [ use_mask ] = this_order
616626 best_quat = this_quat
627+ del this_order
628+ # Second pass: now fit the remaining (bad) coils using the best order and quat
629+ # from above
630+ missing = np .setdiff1d (np .arange (n_coils ), best_order [best_order >= 0 ])
631+ best_missing_g = - np .inf
632+ for this_order in itertools .permutations (missing ):
633+ full_order = best_order .copy ()
634+ full_order [~ use_mask ] = this_order
635+ assert (full_order >= 0 ).all ()
636+ assert np .array_equal (np .sort (full_order ), np .arange (n_coils ))
637+ head_pnts_tmp = head_pnts [np .array (full_order )]
638+ _ , g = _fit_chpi_quat (dev_pnts , head_pnts_tmp , quat = best_quat )
639+ if g > best_missing_g :
640+ best_missing_g = g
641+ best_order [:] = full_order
642+ del this_order
643+ assert np .array_equal (np .sort (best_order ), np .arange (n_coils ))
617644
618645 # Convert Quaterion to transform
619646 dev_head_t = _quat_to_affine (best_quat )
620647 ang , dist = angle_distance_between_rigid (
621648 dev_head_t , angle_units = "deg" , distance_units = "mm"
622649 )
650+ extra = f" using { n_use } /{ n_coils } coils" if n_use < n_coils else ""
623651 logger .info (
624652 f"{ prefix } Fitted dev_head_t { ang :0.1f} ° and { dist :0.1f} mm "
625- f"from device origin (GOF: { out_g :.3f} )"
653+ f"from device origin{ extra } (GOF: { out_g :.3f} )"
626654 )
627655 return dev_head_t , best_order , out_g
628656
@@ -1703,7 +1731,8 @@ def refit_hpi(
17031731 :func:`~mne.chpi.compute_chpi_locs`.
17041732 3. Optionally determine coil digitization order by testing all permutations
17051733 for the best goodness of fit between digitized coil locations and
1706- (rigid-transformed) fitted coil locations.
1734+ (rigid-transformed) fitted coil locations, choosing the order first based on
1735+ those that satisfy ``gof_limit`` then the others.
17071736 4. Subselect coils to use for fitting ``dev_head_t`` based on ``gof_limit``,
17081737 ``dist_limit``, and ``use``.
17091738 5. Update info inplace by modifying ``info["dev_head_t"]`` and appending new entries
@@ -1816,6 +1845,8 @@ def refit_hpi(
18161845 fit_dev_head_t , fit_order , _g = _fit_coil_order_dev_head_trans (
18171846 hpi_dev ,
18181847 hpi_head ,
1848+ gofs = hpi_gofs ,
1849+ gof_limit = gof_limit ,
18191850 prefix = " " ,
18201851 )
18211852 else :
@@ -1824,27 +1855,21 @@ def refit_hpi(
18241855
18251856 # 4. Subselect usable coils and determine final dev_head_t
18261857 if isinstance (use , int ) or use is None :
1827- used = np .where (hpi_gofs >= gof_limit )[0 ]
1828- if len (used ) < 3 :
1829- gofs = ", " .join (f"{ g :.3f} " for g in hpi_gofs )
1830- raise RuntimeError (
1831- f"Only { len (used )} coil{ _pl (used )} with goodness of fit >= { gof_limit } "
1832- f", need at least 3 to refit HPI order (got { gofs } )."
1833- )
1834- quat , _g = _fit_chpi_quat (hpi_dev [used ], hpi_head [fit_order ][used ])
1858+ use_mask = _gof_use_mask (hpi_gofs , gof_limit = gof_limit )
1859+ quat , _g = _fit_chpi_quat (hpi_dev [use_mask ], hpi_head [fit_order ][use_mask ])
18351860 fit_dev_head_t = _quat_to_affine (quat )
18361861 hpi_head_got = apply_trans (fit_dev_head_t , hpi_dev )
18371862 dists = np .linalg .norm (hpi_head_got - hpi_head [fit_order ], axis = 1 )
18381863 dist_str = " " .join (f"{ dist * 1e3 :.1f} " for dist in dists )
18391864 logger .info (f" Coil distances after initial fit: { dist_str } mm" )
1840- good_dists_idx = np .where (dists [used ] <= dist_limit )[0 ]
1865+ good_dists_idx = np .where (dists [use_mask ] <= dist_limit )[0 ]
18411866 if not len (good_dists_idx ) >= 3 :
18421867 raise RuntimeError (
1843- f"Only { len (good_dists_idx )} coil{ _pl (good_dists_idx )} have distance "
1868+ f"Only { len (good_dists_idx )} coil{ _pl (good_dists_idx )} with distance "
18441869 f"<= { dist_limit * 1e3 :.1f} mm, need at least 3 to refit HPI order "
18451870 f"(got distances: { np .round (1e3 * dists , 1 )} )."
18461871 )
1847- used = used [good_dists_idx ]
1872+ used = np . where ( use_mask )[ 0 ] [good_dists_idx ]
18481873 if use is not None :
18491874 used = np .sort (used [np .argsort (hpi_gofs [used ])[- use :]])
18501875 else :
@@ -1927,6 +1952,19 @@ def refit_hpi(
19271952 return info
19281953
19291954
1955+ def _gof_use_mask (hpi_gofs , * , gof_limit ):
1956+ assert isinstance (hpi_gofs , np .ndarray ) and hpi_gofs .ndim == 1
1957+ use_mask = hpi_gofs >= gof_limit
1958+ n_use = use_mask .sum ()
1959+ if n_use < 3 :
1960+ gofs = ", " .join (f"{ g :.3f} " for g in hpi_gofs )
1961+ raise RuntimeError (
1962+ f"Only { n_use } coil{ _pl (n_use )} with goodness of fit >= { gof_limit } "
1963+ f", need at least 3 to refit HPI order (got { gofs } )."
1964+ )
1965+ return use_mask
1966+
1967+
19301968def _sorted_hpi_dig (dig , * , kinds = (FIFF .FIFFV_POINT_HPI ,)):
19311969 return sorted (
19321970 # need .get here because the hpi_result["dig_points"] does not set it
0 commit comments