1+ #!/usr/bin/env python
2+
3+ import argparse
4+ import os
5+ import os .path as op
6+ import numpy as np
7+
8+ from pydesigner .fitting .dwipy import DWI
9+ from pydesigner .system .utils import writeNii
10+ from pydesigner .tractography import odf
11+
12+ '''
13+ basePath = os.path.join('/Volumes','Flashy','HIE_FBI_003','FBWM_b4000')
14+
15+ python run_pydesigner_fit_from_preprocessed.py \
16+ --dwi /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.nii \
17+ --bvec /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.bvec \
18+ --bval /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/dwi_preprocessed.bval \
19+ --mask /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/brain_mask.nii \
20+ --out /Volumes/Flashy/HIE_FBI_003/FBWM_b4000/metrics \
21+ --nthreads 8 \
22+ --res med \
23+ --lmax-fbi 6 \
24+ --rectify \
25+ --fbwm \
26+ --bvec-flips 1 -1 1
27+
28+ '''
29+
30+
31+ def parse_flips (values ):
32+ flips = tuple (float (v ) for v in values )
33+ if len (flips ) != 3 :
34+ raise ValueError ("--bvec-flips must have exactly three values, e.g. 1 -1 1" )
35+ if any (v not in (- 1.0 , 1.0 ) for v in flips ):
36+ raise ValueError ("--bvec-flips values must be ±1" )
37+ return flips
38+
39+
40+ def main ():
41+ parser = argparse .ArgumentParser (
42+ description = "Run PyDesigner fitting only, starting from dwi_preprocessed.nii."
43+ )
44+
45+ parser .add_argument ("--dwi" , required = True , help = "Path to dwi_preprocessed.nii or .nii.gz" )
46+ parser .add_argument ("--bvec" , required = True , help = "Path to dwi_preprocessed.bvec" )
47+ parser .add_argument ("--bval" , required = True , help = "Path to dwi_preprocessed.bval" )
48+ parser .add_argument ("--mask" , required = True , help = "Path to brain_mask.nii or .nii.gz" )
49+ parser .add_argument ("--out" , required = True , help = "Output directory for metrics" )
50+ parser .add_argument ("--nthreads" , type = int , default = 1 )
51+ parser .add_argument ("--res" , default = "med" , choices = ["low" , "med" , "high" ])
52+ parser .add_argument ("--lmax-fbi" , type = int , default = 6 )
53+ parser .add_argument ("--rectify" , action = "store_true" , help = "Rectify FBI fODF" )
54+ parser .add_argument ("--fbwm" , action = "store_true" , help = "Run FBWM metrics if FBI + DKI are available" )
55+ parser .add_argument (
56+ "--bvec-flips" ,
57+ nargs = 3 ,
58+ default = (1 , - 1 , 1 ),
59+ help = "Gradient sign flips for x y z. Default is 1 -1 1 based on your y-flip test." ,
60+ )
61+
62+ args = parser .parse_args ()
63+ flips = parse_flips (args .bvec_flips )
64+
65+ os .makedirs (args .out , exist_ok = True )
66+
67+ print ("\n Loading preprocessed DWI..." )
68+ img = DWI (
69+ imPath = args .dwi ,
70+ bvecPath = args .bvec ,
71+ bvalPath = args .bval ,
72+ mask = args .mask ,
73+ nthreads = args .nthreads ,
74+ )
75+
76+ print (f"Applying bvec flips: { flips } " )
77+ img .grad [:, :3 ] *= np .asarray (flips )[None , :]
78+
79+ print ("Detected protocols:" , img .tensorType ())
80+
81+ print ("\n Fitting tensor model..." )
82+ img .fit (constraints = None )
83+
84+ tensor_type = "dki" if img .isdki () else "dti"
85+ print (f"Tensor type used for tensorReorder(): { tensor_type } " )
86+
87+ DT , KT = img .tensorReorder (tensor_type )
88+
89+ dt_path = op .join (args .out , "DT.nii" )
90+ kt_path = op .join (args .out , "KT.nii" )
91+
92+ writeNii (DT , img .hdr , dt_path )
93+ print (f"Wrote { dt_path } " )
94+
95+ if tensor_type == "dki" :
96+ writeNii (KT , img .hdr , kt_path )
97+ print (f"Wrote { kt_path } " )
98+
99+ # ------------------------------------------------------------
100+ # DTI metrics + DTI ODF SH
101+ # ------------------------------------------------------------
102+ if img .isdti ():
103+ print ("\n Extracting DTI metrics..." )
104+ md , rd , ad , fa , fe , trace = img .extractDTI ()
105+
106+ writeNii (md , img .hdr , op .join (args .out , "dti_md.nii" ))
107+ writeNii (rd , img .hdr , op .join (args .out , "dti_rd.nii" ))
108+ writeNii (ad , img .hdr , op .join (args .out , "dti_ad.nii" ))
109+ writeNii (fa , img .hdr , op .join (args .out , "dti_fa.nii" ))
110+ writeNii (fe , img .hdr , op .join (args .out , "dti_fe.nii" ))
111+ writeNii (trace , img .hdr , op .join (args .out , "dti_trace.nii" ))
112+
113+ print ("Computing DTI ODF SH..." )
114+ dti_model = odf .odfmodel (
115+ dt = dt_path ,
116+ mask = args .mask ,
117+ l_max = 2 ,
118+ res = args .res ,
119+ nthreads = args .nthreads ,
120+ )
121+ dti_odfs = dti_model .dtiodf ()
122+ dti_sh = dti_model .odf2sh (dti_odfs )
123+ dti_model .savenii (dti_sh , op .join (args .out , "dti_odf.nii" ))
124+
125+ # ------------------------------------------------------------
126+ # DKI metrics + DKI ODF SH
127+ # ------------------------------------------------------------
128+ if img .isdki ():
129+ print ("\n Extracting DKI metrics..." )
130+ mk , rk , ak , kfa , mkt , trace = img .extractDKI ()
131+
132+ writeNii (mk , img .hdr , op .join (args .out , "dki_mk.nii" ))
133+ writeNii (rk , img .hdr , op .join (args .out , "dki_rk.nii" ))
134+ writeNii (ak , img .hdr , op .join (args .out , "dki_ak.nii" ))
135+ writeNii (kfa , img .hdr , op .join (args .out , "dki_kfa.nii" ))
136+ writeNii (mkt , img .hdr , op .join (args .out , "dki_mkt.nii" ))
137+ writeNii (trace , img .hdr , op .join (args .out , "dki_trace.nii" ))
138+
139+ print ("Computing DKI ODF SH..." )
140+ dki_model = odf .odfmodel (
141+ dt = dt_path ,
142+ kt = kt_path ,
143+ mask = args .mask ,
144+ l_max = 6 ,
145+ res = args .res ,
146+ nthreads = args .nthreads ,
147+ )
148+ dki_odfs = dki_model .dkiodf (fa_t = 0.90 )
149+ dki_sh = dki_model .odf2sh (dki_odfs )
150+ dki_model .savenii (dki_sh , op .join (args .out , "dki_odf.nii" ))
151+
152+ # ------------------------------------------------------------
153+ # FBI / FBWM
154+ # ------------------------------------------------------------
155+ if img .isfbi ():
156+ print ("\n Running FBI / FBWM..." )
157+ run_fbwm = bool (args .fbwm and img .isfbwm ())
158+
159+ (
160+ zeta ,
161+ faa ,
162+ sph ,
163+ sph_mrtrix ,
164+ awf ,
165+ Da ,
166+ De_mean ,
167+ De_ax ,
168+ De_rad ,
169+ De_fa ,
170+ min_cost ,
171+ min_cost_fn ,
172+ ) = img .fbi (
173+ l_max = args .lmax_fbi ,
174+ fbwm = run_fbwm ,
175+ rectify = args .rectify ,
176+ res = args .res ,
177+ )
178+
179+ writeNii (zeta , img .hdr , op .join (args .out , "fbi_zeta.nii" ))
180+ writeNii (faa , img .hdr , op .join (args .out , "fbi_faa.nii" ))
181+
182+ # Save both. For MRtrix, use the MRtrix/Tournier-converted version.
183+ # writeNii(np.real(sph), img.hdr, op.join(args.out, "fbi_odf_raw.nii"))
184+ writeNii (np .real (sph_mrtrix ), img .hdr , op .join (args .out , "fbi_odf.nii" ))
185+ writeNii (np .real (sph_mrtrix ), img .hdr , op .join (args .out , "fbi_odf_mrtrix.nii" ))
186+
187+ if run_fbwm :
188+ writeNii (awf , img .hdr , op .join (args .out , "fbwm_awf.nii" ))
189+ writeNii (Da , img .hdr , op .join (args .out , "fbwm_Da.nii" ))
190+ writeNii (De_mean , img .hdr , op .join (args .out , "fbwm_De_mean.nii" ))
191+ writeNii (De_ax , img .hdr , op .join (args .out , "fbwm_De_ax.nii" ))
192+ writeNii (De_rad , img .hdr , op .join (args .out , "fbwm_De_rad.nii" ))
193+ writeNii (De_fa , img .hdr , op .join (args .out , "fbwm_fae.nii" ))
194+ writeNii (min_cost , img .hdr , op .join (args .out , "fbwm_minCost.nii" ))
195+ writeNii (min_cost_fn , img .hdr , op .join (args .out , "fbwm_minCostFn.nii" ))
196+
197+ print ("\n Done." )
198+
199+
200+ if __name__ == "__main__" :
201+ main ()
0 commit comments