33import IVIMNET .deep as deep
44import torch
55import warnings
6+ from utilities .data_simulation .GenerateData import GenerateData
67
78class IVIM_NEToptim (OsipiBase ):
89 """
@@ -60,7 +61,7 @@ def initialize(self, bounds, initial_guess, fitS0, traindata, SNR):
6061 if SNR is None :
6162 warnings .warn ('No SNR indicated. Data simulated with SNR = (5-1000)' )
6263 SNR = (5 , 1000 )
63- self .osipi_training_data (self .bvalues ,n = 1000000 ,SNR = SNR )
64+ self .training_data (self .bvalues ,n = 1000000 ,SNR = SNR )
6465 self .arg = Arg ()
6566 if bounds is not None :
6667 self .arg .net_pars .cons_min = bounds [0 ] # Dt, Fp, Ds, S0
@@ -95,7 +96,7 @@ def ivim_fit(self, signals, bvalues, **kwargs):
9596 return results
9697
9798
98- def ivim_fit_full_volume (self , signals , bvalues , ** kwargs ):
99+ def ivim_fit_full_volume (self , signals , bvalues , retrain_on_input_data = False , ** kwargs ):
99100 """Perform the IVIM fit
100101
101102 Args:
@@ -107,7 +108,10 @@ def ivim_fit_full_volume(self, signals, bvalues, **kwargs):
107108 """
108109 if not np .array_equal (bvalues , self .bvalues ):
109110 raise ValueError ("bvalue list at fitting must be identical as the one at initiation, otherwise it will not run" )
111+
110112 signals = self .reshape_to_voxelwise (signals )
113+ if retrain_on_input_data :
114+ self .net = deep .learn_IVIM (signals , self .bvalues , self .arg , net = self .net )
111115 paramsNN = deep .predict_IVIM (signals , self .bvalues , self .net , self .arg )
112116
113117 results = {}
@@ -129,6 +133,17 @@ def reshape_to_voxelwise(self, data):
129133 voxels = int (np .prod (data .shape [:- 1 ])) # e.g., X*Y*Z
130134 return data .reshape (voxels , B )
131135
136+
137+ def training_data (self , bvalues , data = None , SNR = (5 ,1000 ), n = 1000000 ,Drange = (0.0005 ,0.0034 ),frange = (0 ,1 ),Dprange = (0.005 ,0.1 ),rician_noise = False ):
138+ rng = np .random .RandomState (42 )
139+ if data is None :
140+ gen = GenerateData (rng = rng )
141+ data , D , f , Dp = gen .simulate_training_data (bvalues , SNR = SNR , n = n ,Drange = Drange ,frange = frange ,Dprange = Dprange ,rician_noise = rician_noise )
142+ if self .supervised :
143+ self .train_data = {'data' :data ,'D' :D ,'f' :f ,'Dp' :Dp }
144+ else :
145+ self .train_data = {'data' : data }
146+
132147class NetArgs :
133148 def __init__ (self ):
134149 self .optim = 'adam' # these are the optimisers implementd. Choices are: 'sgd'; 'sgdr'; 'adagrad' adam
0 commit comments