Skip to content

Commit 34b8a15

Browse files
moved training data generation
1 parent abbd631 commit 34b8a15

2 files changed

Lines changed: 17 additions & 14 deletions

File tree

src/standardized/IVIM_NEToptim.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import IVIMNET.deep as deep
44
import torch
55
import warnings
6+
from utilities.data_simulation.GenerateData import GenerateData
67

78
class IVIM_NEToptim(OsipiBase):
89
"""
@@ -60,7 +61,7 @@ def initialize(self, bounds, initial_guess, fitS0, traindata, SNR):
6061
if SNR is None:
6162
warnings.warn('No SNR indicated. Data simulated with SNR = (5-1000)')
6263
SNR = (5, 1000)
63-
self.osipi_training_data(self.bvalues,n=1000000,SNR=SNR)
64+
self.training_data(self.bvalues,n=1000000,SNR=SNR)
6465
self.arg=Arg()
6566
if bounds is not None:
6667
self.arg.net_pars.cons_min = bounds[0] # Dt, Fp, Ds, S0
@@ -95,7 +96,7 @@ def ivim_fit(self, signals, bvalues, **kwargs):
9596
return results
9697

9798

98-
def ivim_fit_full_volume(self, signals, bvalues, **kwargs):
99+
def ivim_fit_full_volume(self, signals, bvalues, retrain_on_input_data=False, **kwargs):
99100
"""Perform the IVIM fit
100101
101102
Args:
@@ -107,7 +108,10 @@ def ivim_fit_full_volume(self, signals, bvalues, **kwargs):
107108
"""
108109
if not np.array_equal(bvalues, self.bvalues):
109110
raise ValueError("bvalue list at fitting must be identical as the one at initiation, otherwise it will not run")
111+
110112
signals = self.reshape_to_voxelwise(signals)
113+
if retrain_on_input_data:
114+
self.net = deep.learn_IVIM(signals, self.bvalues, self.arg, net=self.net)
111115
paramsNN = deep.predict_IVIM(signals, self.bvalues, self.net, self.arg)
112116

113117
results = {}
@@ -129,6 +133,17 @@ def reshape_to_voxelwise(self, data):
129133
voxels = int(np.prod(data.shape[:-1])) # e.g., X*Y*Z
130134
return data.reshape(voxels, B)
131135

136+
137+
def training_data(self, bvalues, data=None, SNR=(5,1000), n=1000000,Drange=(0.0005,0.0034),frange=(0,1),Dprange=(0.005,0.1),rician_noise=False):
138+
rng = np.random.RandomState(42)
139+
if data is None:
140+
gen = GenerateData(rng=rng)
141+
data, D, f, Dp = gen.simulate_training_data(bvalues, SNR=SNR, n=n,Drange=Drange,frange=frange,Dprange=Dprange,rician_noise=rician_noise)
142+
if self.supervised:
143+
self.train_data = {'data':data,'D':D,'f':f,'Dp':Dp}
144+
else:
145+
self.train_data = {'data': data}
146+
132147
class NetArgs:
133148
def __init__(self):
134149
self.optim = 'adam' # these are the optimisers implementd. Choices are: 'sgd'; 'sgdr'; 'adagrad' adam

src/wrappers/OsipiBase.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pathlib
55
import sys
66
from tqdm import tqdm
7-
from utilities.data_simulation.GenerateData import GenerateData
87

98

109
class OsipiBase:
@@ -319,14 +318,3 @@ def osipi_simple_bias_and_RMSE_test(self, SNR, bvalues, f, Dstar, D, noise_reali
319318
print(f"Dstar bias:\t{Dstar_bias}\nDstar RMSE:\t{Dstar_RMSE}")
320319
print(f"D bias:\t{D_bias}\nD RMSE:\t{D_RMSE}")
321320

322-
def osipi_training_data(self, bvalues, data=None, SNR=(5,1000), n=1000000,Drange=(0.0005,0.0034),frange=(0,1),Dprange=(0.005,0.1),rician_noise=False):
323-
rng = np.random.RandomState(42)
324-
if data is None:
325-
gen = GenerateData(rng=rng)
326-
data, D, f, Dp = gen.simulate_training_data(bvalues, SNR=SNR, n=n,Drange=Drange,frange=frange,Dprange=Dprange,rician_noise=rician_noise)
327-
if self.supervised:
328-
self.train_data = {'data':data,'D':D,'f':f,'Dp':Dp}
329-
else:
330-
self.train_data = {'data': data}
331-
332-

0 commit comments

Comments
 (0)