Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions mne/preprocessing/fit_spheres_to_mri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
def fit_spheres_to_mri(
subjects_dir, subject, bem_surf, trans, n_spheres, show_spheres=False
):
"""Fits two spheres to MRI using BEM, such that spheres fit while brain but
do not encroach on sensors. For use with Milti-SSS Maxwell Filtering

Parameters
----------
subjects dir: str
director to Freesurfer subjects
subject: str
Subject ID
bem_surf: list
output of mne.make_bem_model(), must be three shell conductivity profiles
trans: str
path to trans file, mri_dev_t information
n_spheres: int
number of spheres to fit, recommended 2
show_spheres: bool
show pyvista plot of the origins and optimized spheres overlayed with the head

Returns
-------
centers: np.ndarray
2x3 array containing the two centers in HEAD coordinate space
can be directly fed into:
raw_msss = mne.preprocessing.maxwell_filter(raw, origin=origins, ...)
for multi-SSS preprocessing


Notes
-----
* Must have vedo and nibabel installed
* Must have run mne watershed BEM using freesurfer segmentation
"""
## --- required imports

import os

import nibabel as nib
import numpy as np
import vedo
from scipy.spatial import KDTree

from .._fiff.constants import FIFF
from ..surface import _CheckInside
from ..transforms import (
apply_trans,
invert_transform,
read_trans,
)

## --- begin
mindist = 2e-3
assert bem_surf[0]["id"] == FIFF.FIFFV_BEM_SURF_ID_HEAD
assert bem_surf[2]["id"] == FIFF.FIFFV_BEM_SURF_ID_BRAIN
scalp, _, inner_skull = bem_surf
inside_scalp = _CheckInside(scalp, mode="pyvista")
inside_skull = _CheckInside(inner_skull, mode="pyvista")
m3_to_cc = 100**3
assert inside_scalp(inner_skull["rr"]).all()
assert not inside_skull(scalp["rr"]).any()
b = vedo.Mesh([inner_skull["rr"], inner_skull["tris"]])
s = vedo.Mesh([scalp["rr"], scalp["tris"]])
s_tree = KDTree(scalp["rr"])
brain_volume = b.volume()
print(f"Brain vedo: {brain_volume * m3_to_cc:8.2f} cc")
brain_vol = nib.load(os.path.join(subjects_dir, subject, "mri", "brainmask.mgz"))
brain_rr = np.array(np.where(brain_vol.get_fdata())).T
brain_rr = (
apply_trans(brain_vol.header.get_vox2ras_tkr(), brain_rr) / 1000.0
) # apply a transformation matrix
del brain_vol # delete brain volume
brain_rr = brain_rr[inside_skull(brain_rr)]
vox_to_m3 = 1e-9
brain_volume_vox = len(brain_rr) * vox_to_m3

def _print_q(title, got, want):
title = f"{title}:".ljust(15)
print(f"{title} {got * m3_to_cc:8.2f} cc ({(want - got) / want * 100:6.2f} %)")

_print_q("Brain vox", brain_volume_vox, brain_volume_vox)

# 1. Compute a naive sphere using the center of mass of brain surf verts
naive_c = np.mean(inner_skull["rr"], axis=0)
naive_r = np.min(np.linalg.norm(inner_skull["rr"] - naive_c, axis=1))
naive_v = 4 / 3 * np.pi * naive_r**3
_print_q("Naive sphere", naive_v, brain_volume)
s1 = vedo.Sphere(naive_c, naive_r, res=100)
_print_q("Naive vedo", s1.volume(), brain_volume)

# 2. Now use the larger radius (to head) plus mesh arithmetic
better_r = s_tree.query(naive_c)[0] - mindist
s1 = vedo.Sphere(naive_c, better_r, res=24)
_print_q("Better vedo", s1.boolean("intersect", b).volume(), brain_volume)
v = np.sum(np.linalg.norm(brain_rr - naive_c, axis=1) <= better_r) * vox_to_m3
_print_q("Better vox", v, brain_volume_vox)

# 3. Now optimize one sphere
from scipy.optimize import (
fmin_cobyla, # constrained optimization by linear approximation
)

def _cost(c):
cs = c.reshape(-1, 3)
rs = np.maximum(s_tree.query(cs)[0] - mindist, 0.0)
resid = brain_volume
mask = None
for c, r in zip(cs, rs):
if not (r and s.contains(c)): # was is_inside
continue
m = np.linalg.norm(brain_rr - c, axis=1) <= r
if mask is None:
mask = m
else:
mask |= m
resid = brain_volume_vox
if mask is not None:
resid = resid - np.sum(mask) * vox_to_m3
return resid

def _cons(c):
cs = c.reshape(-1, 3)
sign = np.array([2 * s.contains(c) - 1 for c in cs], float) # was "is_inside"
cons = sign * s_tree.query(cs)[0] - mindist
return cons

x = naive_c
c_opt_1 = fmin_cobyla(_cost, x, _cons, rhobeg=1e-2, rhoend=1e-4)
v_opt_1 = brain_volume_vox - _cost(c_opt_1)
_print_q("COBYLA 1", v_opt_1, brain_volume_vox)

# 4. Now optimize two spheres
x = np.concatenate([c_opt_1, naive_c])
c_opt_2 = fmin_cobyla(_cost, x, _cons, rhobeg=1e-2, rhoend=1e-4)
v_opt_2 = brain_volume_vox - _cost(c_opt_2)
_print_q("COBYLA 2", v_opt_2, brain_volume_vox)

# 4. Finally, three spheres (not perfect, not global opt)
x = np.concatenate([c_opt_2, naive_c])
c_opt_3 = fmin_cobyla(_cost, x, _cons, rhobeg=1e-2, rhoend=1e-4)
v_opt_3 = brain_volume_vox - _cost(c_opt_3)
_print_q("COBYLA 3", v_opt_3, brain_volume_vox)

if show_spheres:
import matplotlib
import pyvista as pv
import pyvistaqt

plotter = pyvistaqt.BackgroundPlotter(
shape=(1, 2),
window_size=(1200, 300),
editor=False,
menu_bar=False,
toolbar=False,
)
plotter.background_color = "w"
brain_mesh = pv.make_tri_mesh(inner_skull["rr"], inner_skull["tris"])
scalp_mesh = pv.make_tri_mesh(scalp["rr"], scalp["tris"])
colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]
mesh_kwargs = dict(render=False, reset_camera=False, smooth_shading=True)
for ci, cs in enumerate((c_opt_1, c_opt_2, c_opt_3)):
plotter.subplot(0, ci)
plotter.camera.position = (0.0, -0.5, 0)
plotter.camera.focal_point = (0.0, 0.0, 0.0)
plotter.camera.azimuth = 90
plotter.camera.elevation = 0
plotter.camera.up = (0.0, 0.0, 1.0)
plotter.add_mesh(brain_mesh, opacity=0.2, color="k", **mesh_kwargs)
plotter.add_mesh(scalp_mesh, opacity=0.1, color="tan", **mesh_kwargs)
for c, color in zip(cs.reshape(-1, 3), colors):
sphere = pv.Sphere(s_tree.query(c)[0] - mindist, c)
plotter.add_mesh(sphere, opacity=0.5, color=color, **mesh_kwargs)
plotter.show()

# Ready centers to output, transform into device space
mri_head_t = invert_transform(read_trans(trans))
if mri_head_t["from"] == FIFF.FIFFV_COORD_HEAD:
mri_head_t = invert_transform(mri_head_t)
assert mri_head_t["from"] == FIFF.FIFFV_COORD_MRI, mri_head_t["from"]
centers = []
for use in (c_opt_1, c_opt_2, c_opt_3):
centers.append(apply_trans(mri_head_t, use.reshape(-1, 3)))
if n_spheres == 1:
return centers[0]
if n_spheres == 2:
return centers[1]
if n_spheres == 3:
print(
"Warning: use of mSSS with three origins and expansions is not tested or recommended"
)
return centers[2]
108 changes: 80 additions & 28 deletions mne/preprocessing/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,34 +403,41 @@ def maxwell_filter(
.. footbibliography::
""" # noqa: E501
logger.info("Maxwell filtering raw data")
params = _prep_maxwell_filter(
raw=raw,
origin=origin,
int_order=int_order,
ext_order=ext_order,
calibration=calibration,
cross_talk=cross_talk,
st_duration=st_duration,
st_correlation=st_correlation,
coord_frame=coord_frame,
destination=destination,
regularize=regularize,
ignore_ref=ignore_ref,
bad_condition=bad_condition,
head_pos=head_pos,
st_fixed=st_fixed,
st_only=st_only,
mag_scale=mag_scale,
skip_by_annotation=skip_by_annotation,
extended_proj=extended_proj,
st_overlap=st_overlap,
mc_interp=mc_interp,
)
raw_sss = _run_maxwell_filter(raw, **params)
# Update info
_update_sss_info(raw_sss, **params["update_kwargs"])
logger.info("[done]")
return raw_sss
# TODO: fix
if isinstance(origin, np.ndarray):
raw_mSSS = _run_mSSS(raw, origin)
# Update info _update_sss_info(raw_sss, **params["update_kwargs"]) ??
return raw_mSSS

else:
params = _prep_maxwell_filter(
raw=raw,
origin=origin,
int_order=int_order,
ext_order=ext_order,
calibration=calibration,
cross_talk=cross_talk,
st_duration=st_duration,
st_correlation=st_correlation,
coord_frame=coord_frame,
destination=destination,
regularize=regularize,
ignore_ref=ignore_ref,
bad_condition=bad_condition,
head_pos=head_pos,
st_fixed=st_fixed,
st_only=st_only,
mag_scale=mag_scale,
skip_by_annotation=skip_by_annotation,
extended_proj=extended_proj,
st_overlap=st_overlap,
mc_interp=mc_interp,
)
raw_sss = _run_maxwell_filter(raw, **params)
# Update info
_update_sss_info(raw_sss, **params["update_kwargs"])
logger.info("[done]")
return raw_sss


@verbose
Expand Down Expand Up @@ -834,6 +841,37 @@ def _run_maxwell_filter(
return raw_sss


def _run_mSSS(raw, origin):
if len(origin) > 2:
# TODO: fix error msg
raise ValueError("n > 2")

S_in = []
for i in range(len(origin)):
[S_i, _, _, moments_i] = compute_maxwell_basis(
raw.info, origin=origin[i], regularize=None, bad_condition="ignore"
)
S_in.append(S_i[:, :moments_i])

S_tot = _combine_sss_basis(S_in[0], S_in[1])

raw_mSSS = maxwell_filter(
raw, origin="auto", regularize=None, ignore_ref=True, bad_condition="ignore"
)

meg_picks = pick_types(raw_mSSS.info, meg=True)

# reconstruct
phi_0 = raw.get_data(picks="meg")
pS = np.linalg.pinv(S_tot)
XN = pS @ phi_0
mSSS_data = np.real(S_tot @ XN)

raw_mSSS._data[meg_picks] = mSSS_data

return raw_mSSS


class _MoveComp:
"""Perform movement compensation."""

Expand Down Expand Up @@ -1866,6 +1904,20 @@ def _sss_basis(exp, all_coils):
return S_tot


def _combine_sss_basis(S_in1, S_in2):
"""MSSS calculations using optimized multi-centers"""
# TODO: n > 2
S_tot = []
thresh = 5e-7 # 0.005 in Matlab
U, s, Vh = np.linalg.svd(np.concatenate((S_in1, S_in2), axis=1))
# apply threshold to limit dimensions of resulting basis
for i in range(0, np.shape(s)[0]):
ratio = s[i] / s[0]
if ratio >= thresh:
S_tot.append(U[:, i])
return np.transpose(np.array(S_tot))


def _integrate_points(
cos_az, sin_az, cos_pol, sin_pol, b_r, b_az, b_pol, cosmags, bins, n_coils
):
Expand Down
Loading