Skip to content

Commit 3e429f5

Browse files
update
1 parent 00bbc8e commit 3e429f5

2 files changed

Lines changed: 221 additions & 2 deletions

File tree

pydesigner/fitting/dwipy.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
bvalPath: str = None,
7070
mask: str = None,
7171
nthreads: int = -1,
72+
bvec_flips: tuple = (1, -1, 1),
7273
) -> None:
7374
"""DWI class initializer
7475
@@ -117,8 +118,25 @@ def __init__(
117118
# Combine bvecs and bvals into [n x 4] array where n is
118119
# number of DWI volumes. [Gx Gy Gz Bval]
119120
self.grad = np.c_[np.transpose(bvecs), bvals]
121+
122+
# Apply optional bvec sign correction.
123+
# Your test showed that the correct correction is a y-flip: (1, -1, 1).
124+
bvec_flips = np.asarray(bvec_flips, dtype=float)
125+
126+
if bvec_flips.shape != (3,):
127+
raise ValueError(
128+
"bvec_flips must be a 3-element tuple/list, e.g. (1, -1, 1)."
129+
)
130+
131+
if not np.all(np.isin(bvec_flips, [-1, 1])):
132+
raise ValueError(
133+
"bvec_flips values must be either 1 or -1, e.g. (1, -1, 1)."
134+
)
135+
136+
self.grad[:, :3] *= bvec_flips[None, :]
137+
138+
print(f"Applied bvec flips: {tuple(bvec_flips.astype(int))}")
120139
print(f"bvecs, bvals (shape): {np.shape(self.grad)}")
121-
# self.grad = np.c_[bvecs, bvals]
122140
else:
123141
msg = "Unable to locate BVAL or BVEC files"
124142
msg += "\nPaths being used are:"
@@ -873,7 +891,7 @@ def fit(self, constraints: Union[np.ndarray[float], None] = None, reject: bool =
873891
if reject is None:
874892
reject = np.zeros(self.img[:, :, :, exclude_idx].shape)
875893
grad = self.grad[exclude_idx, :]
876-
grad_orig = grad
894+
grad_orig = grad.copy()
877895
order = np.floor(np.log(np.abs(np.max(grad[:, -1]) + 1)) / np.log(10))
878896
img = self.img[:, :, :, exclude_idx]
879897
if order >= 2:

run_pydesigner_fit.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#!/usr/bin/env python
2+
3+
import argparse
4+
import os
5+
import os.path as op
6+
import numpy as np
7+
8+
from pydesigner.fitting.dwipy import DWI
9+
from pydesigner.system.utils import writeNii
10+
from pydesigner.tractography import odf
11+
12+
'''
13+
basePath = os.path.join('/Volumes','Flashy','HIE_FBI_003','FBWM_b4000')
14+
15+
python run_pydesigner_fit_from_preprocessed.py \
16+
--dwi /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.nii \
17+
--bvec /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.bvec \
18+
--bval /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.bval \
19+
--mask /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/brain_mask.nii \
20+
--out /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/metrics \
21+
--nthreads 8 \
22+
--res med \
23+
--lmax-fbi 6 \
24+
--rectify \
25+
--fbwm \
26+
--bvec-flips 1 -1 1
27+
28+
'''
29+
30+
31+
def parse_flips(values):
32+
flips = tuple(float(v) for v in values)
33+
if len(flips) != 3:
34+
raise ValueError("--bvec-flips must have exactly three values, e.g. 1 -1 1")
35+
if any(v not in (-1.0, 1.0) for v in flips):
36+
raise ValueError("--bvec-flips values must be ±1")
37+
return flips
38+
39+
40+
def main():
41+
parser = argparse.ArgumentParser(
42+
description="Run PyDesigner fitting only, starting from dwi_preprocessed.nii."
43+
)
44+
45+
parser.add_argument("--dwi", required=True, help="Path to dwi_preprocessed.nii or .nii.gz")
46+
parser.add_argument("--bvec", required=True, help="Path to dwi_preprocessed.bvec")
47+
parser.add_argument("--bval", required=True, help="Path to dwi_preprocessed.bval")
48+
parser.add_argument("--mask", required=True, help="Path to brain_mask.nii or .nii.gz")
49+
parser.add_argument("--out", required=True, help="Output directory for metrics")
50+
parser.add_argument("--nthreads", type=int, default=1)
51+
parser.add_argument("--res", default="med", choices=["low", "med", "high"])
52+
parser.add_argument("--lmax-fbi", type=int, default=6)
53+
parser.add_argument("--rectify", action="store_true", help="Rectify FBI fODF")
54+
parser.add_argument("--fbwm", action="store_true", help="Run FBWM metrics if FBI + DKI are available")
55+
parser.add_argument(
56+
"--bvec-flips",
57+
nargs=3,
58+
default=(1, -1, 1),
59+
help="Gradient sign flips for x y z. Default is 1 -1 1 based on your y-flip test.",
60+
)
61+
62+
args = parser.parse_args()
63+
flips = parse_flips(args.bvec_flips)
64+
65+
os.makedirs(args.out, exist_ok=True)
66+
67+
print("\nLoading preprocessed DWI...")
68+
img = DWI(
69+
imPath=args.dwi,
70+
bvecPath=args.bvec,
71+
bvalPath=args.bval,
72+
mask=args.mask,
73+
nthreads=args.nthreads,
74+
)
75+
76+
print(f"Applying bvec flips: {flips}")
77+
img.grad[:, :3] *= np.asarray(flips)[None, :]
78+
79+
print("Detected protocols:", img.tensorType())
80+
81+
print("\nFitting tensor model...")
82+
img.fit(constraints=None)
83+
84+
tensor_type = "dki" if img.isdki() else "dti"
85+
print(f"Tensor type used for tensorReorder(): {tensor_type}")
86+
87+
DT, KT = img.tensorReorder(tensor_type)
88+
89+
dt_path = op.join(args.out, "DT.nii")
90+
kt_path = op.join(args.out, "KT.nii")
91+
92+
writeNii(DT, img.hdr, dt_path)
93+
print(f"Wrote {dt_path}")
94+
95+
if tensor_type == "dki":
96+
writeNii(KT, img.hdr, kt_path)
97+
print(f"Wrote {kt_path}")
98+
99+
# ------------------------------------------------------------
100+
# DTI metrics + DTI ODF SH
101+
# ------------------------------------------------------------
102+
if img.isdti():
103+
print("\nExtracting DTI metrics...")
104+
md, rd, ad, fa, fe, trace = img.extractDTI()
105+
106+
writeNii(md, img.hdr, op.join(args.out, "dti_md.nii"))
107+
writeNii(rd, img.hdr, op.join(args.out, "dti_rd.nii"))
108+
writeNii(ad, img.hdr, op.join(args.out, "dti_ad.nii"))
109+
writeNii(fa, img.hdr, op.join(args.out, "dti_fa.nii"))
110+
writeNii(fe, img.hdr, op.join(args.out, "dti_fe.nii"))
111+
writeNii(trace, img.hdr, op.join(args.out, "dti_trace.nii"))
112+
113+
print("Computing DTI ODF SH...")
114+
dti_model = odf.odfmodel(
115+
dt=dt_path,
116+
mask=args.mask,
117+
l_max=2,
118+
res=args.res,
119+
nthreads=args.nthreads,
120+
)
121+
dti_odfs = dti_model.dtiodf()
122+
dti_sh = dti_model.odf2sh(dti_odfs)
123+
dti_model.savenii(dti_sh, op.join(args.out, "dti_odf.nii"))
124+
125+
# ------------------------------------------------------------
126+
# DKI metrics + DKI ODF SH
127+
# ------------------------------------------------------------
128+
if img.isdki():
129+
print("\nExtracting DKI metrics...")
130+
mk, rk, ak, kfa, mkt, trace = img.extractDKI()
131+
132+
writeNii(mk, img.hdr, op.join(args.out, "dki_mk.nii"))
133+
writeNii(rk, img.hdr, op.join(args.out, "dki_rk.nii"))
134+
writeNii(ak, img.hdr, op.join(args.out, "dki_ak.nii"))
135+
writeNii(kfa, img.hdr, op.join(args.out, "dki_kfa.nii"))
136+
writeNii(mkt, img.hdr, op.join(args.out, "dki_mkt.nii"))
137+
writeNii(trace, img.hdr, op.join(args.out, "dki_trace.nii"))
138+
139+
print("Computing DKI ODF SH...")
140+
dki_model = odf.odfmodel(
141+
dt=dt_path,
142+
kt=kt_path,
143+
mask=args.mask,
144+
l_max=6,
145+
res=args.res,
146+
nthreads=args.nthreads,
147+
)
148+
dki_odfs = dki_model.dkiodf(fa_t=0.90)
149+
dki_sh = dki_model.odf2sh(dki_odfs)
150+
dki_model.savenii(dki_sh, op.join(args.out, "dki_odf.nii"))
151+
152+
# ------------------------------------------------------------
153+
# FBI / FBWM
154+
# ------------------------------------------------------------
155+
if img.isfbi():
156+
print("\nRunning FBI / FBWM...")
157+
run_fbwm = bool(args.fbwm and img.isfbwm())
158+
159+
(
160+
zeta,
161+
faa,
162+
sph,
163+
sph_mrtrix,
164+
awf,
165+
Da,
166+
De_mean,
167+
De_ax,
168+
De_rad,
169+
De_fa,
170+
min_cost,
171+
min_cost_fn,
172+
) = img.fbi(
173+
l_max=args.lmax_fbi,
174+
fbwm=run_fbwm,
175+
rectify=args.rectify,
176+
res=args.res,
177+
)
178+
179+
writeNii(zeta, img.hdr, op.join(args.out, "fbi_zeta.nii"))
180+
writeNii(faa, img.hdr, op.join(args.out, "fbi_faa.nii"))
181+
182+
# Save both. For MRtrix, use the MRtrix/Tournier-converted version.
183+
# writeNii(np.real(sph), img.hdr, op.join(args.out, "fbi_odf_raw.nii"))
184+
writeNii(np.real(sph_mrtrix), img.hdr, op.join(args.out, "fbi_odf.nii"))
185+
writeNii(np.real(sph_mrtrix), img.hdr, op.join(args.out, "fbi_odf_mrtrix.nii"))
186+
187+
if run_fbwm:
188+
writeNii(awf, img.hdr, op.join(args.out, "fbwm_awf.nii"))
189+
writeNii(Da, img.hdr, op.join(args.out, "fbwm_Da.nii"))
190+
writeNii(De_mean, img.hdr, op.join(args.out, "fbwm_De_mean.nii"))
191+
writeNii(De_ax, img.hdr, op.join(args.out, "fbwm_De_ax.nii"))
192+
writeNii(De_rad, img.hdr, op.join(args.out, "fbwm_De_rad.nii"))
193+
writeNii(De_fa, img.hdr, op.join(args.out, "fbwm_fae.nii"))
194+
writeNii(min_cost, img.hdr, op.join(args.out, "fbwm_minCost.nii"))
195+
writeNii(min_cost_fn, img.hdr, op.join(args.out, "fbwm_minCostFn.nii"))
196+
197+
print("\nDone.")
198+
199+
200+
if __name__ == "__main__":
201+
main()

0 commit comments

Comments
 (0)