Skip to content

Commit 19eb2b2

Browse files
axel-grcCopilot
andcommitted
ENH: Add rtkspectralonestep python application
Co-authored-by: Copilot <copilot@github.com>
1 parent 8f46b47 commit 19eb2b2

4 files changed

Lines changed: 319 additions & 0 deletions

File tree

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
#!/usr/bin/env python
2+
import argparse
3+
import itk
4+
import numpy as np
5+
from itk import RTK as rtk
6+
7+
8+
def build_parser():
9+
parser = rtk.RTKArgumentParser(
10+
description="One-step spectral reconstruction (Python port)"
11+
)
12+
parser.add_argument("--geometry", "-g", required=True, help="XML geometry file name")
13+
parser.add_argument("--output", "-o", required=True, help="Output file name")
14+
parser.add_argument("--niterations", "-n", type=int, default=1, help="Number of iterations")
15+
parser.add_argument("--input", "-i", help="Material volumes initial guess")
16+
parser.add_argument(
17+
"--spectral", "-s", required=True, help="Spectral projections, i.e. photon counts"
18+
)
19+
parser.add_argument("--detector", "-d", required=True, help="Detector response file")
20+
parser.add_argument("--incident", required=True, help="Incident spectrum file (mhd image)")
21+
parser.add_argument(
22+
"--attenuations", "-a", required=True, help="Material attenuations file"
23+
)
24+
parser.add_argument("--mask", "-m", help="Apply a support binary mask: reconstruction kept null outside", default=None)
25+
parser.add_argument(
26+
"--regul_spatial_weights",
27+
help="Spatial regularization weights file",
28+
default=None,
29+
)
30+
parser.add_argument(
31+
"--projection_weights", help="Projection weights file", default=None
32+
)
33+
parser.add_argument(
34+
"--thresholds",
35+
"-t",
36+
type=float,
37+
nargs="+",
38+
required=True,
39+
help="Lower threshold of bins, expressed in pulse height",
40+
)
41+
parser.add_argument("--subsets", type=int, default=1, help="Number of subsets of projections (should not exceed 6)")
42+
parser.add_argument(
43+
"--regul_weights",
44+
type=float,
45+
nargs="+",
46+
help="Regularization parameters for each material",
47+
)
48+
parser.add_argument(
49+
"--regul_radius", type=int, nargs="+", help="Radius of the neighborhood for regularization"
50+
)
51+
parser.add_argument("--reset_nesterov", type=int, default=1, help="Reset Nesterov after a number of subsets")
52+
53+
rtk.add_rtkiterations_group(parser)
54+
rtk.add_rtkprojectors_group(parser)
55+
56+
return parser
57+
58+
59+
def GetFileHeader(filename: str):
60+
io = itk.ImageIOFactory.CreateImageIO(filename, itk.CommonEnums.IOFileMode_ReadMode)
61+
if not io:
62+
raise RuntimeError(f"ImageIOFactory could not create an ImageIO for '{filename}'")
63+
io.SetFileName(filename)
64+
io.ReadImageInformation()
65+
return io
66+
67+
68+
def spectral_bin_detector_response(drm_img, thresholds):
69+
# drm_img: itk image 2D (energies, pulseHeights)
70+
region = drm_img.GetLargestPossibleRegion()
71+
size = region.GetSize()
72+
numberOfEnergies = size[0]
73+
74+
numberOfSpectralBins = len(thresholds) - 1
75+
76+
binnedResponse = np.zeros((numberOfSpectralBins, numberOfEnergies), dtype=float)
77+
78+
indexDet = itk.Index[2]()
79+
for energy in range(numberOfEnergies):
80+
indexDet[0] = energy
81+
for bin in range(numberOfSpectralBins):
82+
# First and last couple of values:
83+
# use trapezoidal rule with linear interpolation
84+
infPulse = int(np.floor(thresholds[bin]))
85+
if infPulse < 1:
86+
raise RuntimeError(f"Threshold {thresholds[bin]} below 0 keV")
87+
88+
supPulse = int(np.floor(thresholds[bin + 1]))
89+
if float(supPulse) == thresholds[bin + 1]:
90+
supPulse -= 1
91+
92+
if supPulse - infPulse < 3:
93+
raise RuntimeError("Thresholds are too close for the current code.")
94+
95+
wInf = infPulse + 1.0 - thresholds[bin]
96+
indexDet[1] = infPulse - 1 # Index 0 is 1 keV
97+
binnedResponse[bin, energy] += 0.5 * wInf * wInf * drm_img.GetPixel(indexDet)
98+
99+
indexDet[1] += 1
100+
binnedResponse[bin, energy] += 0.5 * (1.0 + wInf * (2.0 - wInf)) * drm_img.GetPixel(indexDet)
101+
102+
wSup = thresholds[bin + 1] - supPulse
103+
indexDet[1] = supPulse #Index 0 is 1 keV
104+
if supPulse >= drm_img.GetLargestPossibleRegion().GetSize()[1]:
105+
raise RuntimeError(
106+
f"Threshold {thresholds[bin+1]} above max {drm_img.GetLargestPossibleRegion().GetSize()[1]}"
107+
)
108+
binnedResponse[bin, energy] += 0.5 * wSup * wSup * drm_img.GetPixel(indexDet)
109+
110+
indexDet[1] -= 1
111+
binnedResponse[bin, energy] += 0.5 * (1.0 + wSup * (2.0 - wSup)) * drm_img.GetPixel(indexDet)
112+
113+
# Intermediate values
114+
for pulseHeight in range(infPulse + 1, supPulse - 1):
115+
indexDet[1] = pulseHeight
116+
binnedResponse[bin, energy] += drm_img.GetPixel(indexDet)
117+
rows, cols = binnedResponse.shape
118+
v = itk.vnl_matrix[itk.F](rows, cols)
119+
for i in range(rows):
120+
row = binnedResponse[i]
121+
for j in range(cols):
122+
v.put(i, j, row[j])
123+
return v
124+
125+
126+
def process(args_info: argparse.Namespace):
127+
dataType = itk.F
128+
Dimension = 3
129+
130+
headerInputMeasuredProjections = GetFileHeader(args_info.spectral)
131+
headerAttenuations = GetFileHeader(args_info.attenuations)
132+
nBins = headerInputMeasuredProjections.GetNumberOfComponents()
133+
nMaterials = headerAttenuations.GetDimensions(0)
134+
135+
# Define types for the input images
136+
MeasuredProjectionsType = itk.Image[
137+
itk.Vector[dataType, nBins], Dimension
138+
]
139+
MaterialVolumesType = itk.Image[
140+
itk.Vector[dataType, nMaterials], Dimension
141+
]
142+
IncidentSpectrumType = itk.Image[dataType, Dimension]
143+
DetectorResponseType = itk.Image[dataType, Dimension - 1]
144+
MaterialAttenuationsType = itk.Image[dataType, Dimension - 1]
145+
146+
# Instantiate and update the readers
147+
mea = itk.ImageFileReader[MeasuredProjectionsType].New()
148+
mea.SetFileName(args_info.spectral)
149+
mea.Update()
150+
mea = mea.GetOutput()
151+
152+
incidentSpectrum = itk.ImageFileReader[IncidentSpectrumType].New()
153+
incidentSpectrum.SetFileName(args_info.incident)
154+
incidentSpectrum.Update()
155+
incidentSpectrum = incidentSpectrum.GetOutput()
156+
157+
detectorResponse = itk.ImageFileReader[DetectorResponseType].New()
158+
detectorResponse.SetFileName(args_info.detector)
159+
detectorResponse.Update()
160+
detectorResponse = detectorResponse.GetOutput()
161+
162+
materialAttenuations = itk.ImageFileReader[MaterialAttenuationsType].New()
163+
materialAttenuations.SetFileName(args_info.attenuations)
164+
materialAttenuations.Update()
165+
materialAttenuations = materialAttenuations.GetOutput()
166+
167+
# Read Support Mask if given
168+
if args_info.mask:
169+
supportmask = itk.imread(args_info.mask)
170+
171+
# Read spatial regularization weights if given
172+
if args_info.regul_spatial_weights:
173+
spatialRegulWeighs = itk.imread(args_info.regul_spatial_weights)
174+
175+
#Read projections weights if given
176+
if args_info.projection_weights:
177+
projectionWeights = itk.imread(args_info.projection_weights)
178+
179+
# Create input: either an existing volume read from a file or a blank image
180+
if args_info.input is not None:
181+
input = itk.ImageFileReader[MaterialVolumesType].New()
182+
input.SetFileName(args_info.input)
183+
input.Update()
184+
input = input.GetOutput()
185+
else:
186+
constantImageSource = itk.ConstantImageSource[MaterialVolumesType].New()
187+
rtk.SetConstantImageSourceFromArgParse(constantImageSource, args_info)
188+
input = constantImageSource.GetOutput()
189+
190+
# Read the material attenuations image as a matrix (C++ style)
191+
indexMat = itk.Index[2]()
192+
nEnergies = materialAttenuations.GetLargestPossibleRegion().GetSize()[1]
193+
materialAttenuationsMatrix = itk.vnl_matrix[itk.F](nEnergies, nMaterials)
194+
for energy in range(nEnergies):
195+
indexMat[1] = energy
196+
for material in range(nMaterials):
197+
indexMat[0] = material
198+
materialAttenuationsMatrix.put(energy, material, materialAttenuations.GetPixel(indexMat))
199+
200+
thresholds = list(args_info.thresholds)
201+
MaximumPulseHeight = detectorResponse.GetLargestPossibleRegion().GetSize()[1]
202+
thresholds.append(MaximumPulseHeight)
203+
if len(thresholds) - 1 != nBins:
204+
raise RuntimeError(f"Number of thresholds {len(thresholds) - 1} does not match the number of bins {nBins}")
205+
206+
# Read the detector response image as a matrix, and bin it
207+
drm = spectral_bin_detector_response(detectorResponse, thresholds)
208+
209+
# Geometry
210+
if args_info.verbose:
211+
print(f"Reading geometry from {args_info.geometry} ...")
212+
geometry = rtk.read_geometry(args_info.geometry)
213+
214+
# Read the regularization parameters
215+
regulRadius = itk.Size[3]()
216+
if args_info.regul_radius:
217+
for i in range(3):
218+
regulRadius[i] = args_info.regul_radius[min(i, len(args_info.regul_radius) - 1)]
219+
else:
220+
regulRadius.Fill(0)
221+
222+
regulWeights = itk.Vector[dataType, nMaterials]()
223+
if args_info.regul_weights:
224+
for i in range(nMaterials):
225+
regulWeights[i] = args_info.regul_weights[min(i, len(args_info.regul_weights) - 1)]
226+
else:
227+
regulWeights.Fill(0.)
228+
229+
if hasattr(itk, "CudaImage"):
230+
CudaMeasuredProjectionsType = itk.CudaImage[
231+
itk.Vector[dataType, nBins], Dimension
232+
]
233+
CudaMaterialVolumesType = itk.CudaImage[
234+
itk.Vector[dataType, nMaterials], Dimension
235+
]
236+
CudaIncidentSpectrumType = itk.CudaImage[dataType, Dimension]
237+
238+
mechlemOneStep = rtk.MechlemOneStepSpectralReconstructionFilter[
239+
CudaMaterialVolumesType, CudaMeasuredProjectionsType, CudaIncidentSpectrumType
240+
].New()
241+
242+
mechlemOneStep.SetInputMaterialVolumes(itk.cuda_image_from_image(input))
243+
mechlemOneStep.SetInputIncidentSpectrum(itk.cuda_image_from_image(incidentSpectrum))
244+
mechlemOneStep.SetInputMeasuredProjections(itk.cuda_image_from_image(mea))
245+
if args_info.mask:
246+
mechlemOneStep.SetSupportMask(itk.cuda_image_from_image(supportmask))
247+
if args_info.regul_spatial_weights:
248+
mechlemOneStep.SetSpatialRegularizationWeights(itk.cuda_image_from_image(spatialRegulWeighs))
249+
if args_info.projection_weights:
250+
mechlemOneStep.SetProjectionWeights(itk.cuda_image_from_image(projectionWeights))
251+
else:
252+
mechlemOneStep = rtk.MechlemOneStepSpectralReconstructionFilter[
253+
MaterialVolumesType, MeasuredProjectionsType, IncidentSpectrumType
254+
].New()
255+
256+
mechlemOneStep.SetInputMaterialVolumes(input)
257+
mechlemOneStep.SetInputIncidentSpectrum(incidentSpectrum)
258+
mechlemOneStep.SetInputMeasuredProjections(mea)
259+
if args_info.mask:
260+
mechlemOneStep.SetSupportMask(supportmask)
261+
if args_info.regul_spatial_weights:
262+
mechlemOneStep.SetSpatialRegularizationWeights(spatialRegulWeighs)
263+
if args_info.projection_weights:
264+
mechlemOneStep.SetProjectionWeights(projectionWeights)
265+
266+
rtk.SetIterationsReportFromArgParse(args_info, mechlemOneStep)
267+
rtk.SetForwardProjectionFromArgParse(args_info, mechlemOneStep)
268+
rtk.SetBackProjectionFromArgParse(args_info, mechlemOneStep)
269+
270+
mechlemOneStep.SetBinnedDetectorResponse(drm)
271+
mechlemOneStep.SetMaterialAttenuations(materialAttenuationsMatrix)
272+
mechlemOneStep.SetNumberOfIterations(args_info.niterations)
273+
mechlemOneStep.SetNumberOfSubsets(args_info.subsets)
274+
mechlemOneStep.SetRegularizationRadius(regulRadius)
275+
mechlemOneStep.SetRegularizationWeights(regulWeights)
276+
if args_info.reset_nesterov:
277+
mechlemOneStep.SetResetNesterovEvery(args_info.reset_nesterov)
278+
mechlemOneStep.SetGeometry(geometry)
279+
280+
mechlemOneStep.Update()
281+
282+
# Write output
283+
WriterType = itk.ImageFileWriter[MaterialVolumesType]
284+
writer = WriterType.New()
285+
writer.SetFileName(args_info.output)
286+
writer.SetInput(mechlemOneStep.GetOutput())
287+
writer.SetImageIO(itk.MetaImageIO.New())
288+
writer.Update()
289+
290+
291+
def main(argv=None):
292+
parser = build_parser()
293+
args_info = parser.parse_args(argv)
294+
process(args_info)
295+
296+
297+
if __name__ == "__main__":
298+
main()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ rtksart = "itk.rtksart:main"
8383
rtksimulatedgeometry = "itk.rtksimulatedgeometry:main"
8484
rtkspectralforwardmodel = "itk.rtkspectralforwardmodel:main"
8585
rtkspectraldenoiseprojections = "itk.rtkspectraldenoiseprojections:main"
86+
rtkspectralonestep = "itk.rtkspectralonestep:main"
8687
rtkspectralrooster = "itk.rtkspectralrooster:main"
8788
rtksubselect = "itk.rtksubselect:main"
8889
rtktotalvariationdenoising = "itk.rtktotalvariationdenoising:main"

wrapping/__init_rtk__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
"rtksimulatedgeometry",
6363
"rtkspectralforwardmodel",
6464
"rtkspectraldenoiseprojections",
65+
"rtkspectralonestep",
6566
"rtkspectralrooster",
6667
"rtksubselect",
6768
"rtktotalnuclearvariationdenoising",
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
itk_wrap_include(itkImage.h)
2+
3+
itk_wrap_class("itk::CudaImageFromImageFilter" POINTER)
4+
5+
UNIQUE(vector_types "${WRAP_ITK_VECTOR_REAL};${WRAP_ITK_COV_VECTOR_REAL}")
6+
set(vectorComponents 2 3 4 5)
7+
foreach(c ${vectorComponents})
8+
list(FIND ITK_WRAP_VECTOR_COMPONENTS "${c}" _index)
9+
if(${_index} EQUAL -1)
10+
UNIQUE(imageDimensions "${ITK_WRAP_IMAGE_DIMS}")
11+
foreach(d ${imageDimensions})
12+
foreach(vt ${vector_types})
13+
itk_wrap_template("${ITKM_${vt}${c}}${d}" "itk::Image<${ITKT_${vt}${c}}, ${d}>")
14+
endforeach()
15+
endforeach()
16+
endif()
17+
endforeach()
18+
19+
itk_end_wrap_class()

0 commit comments

Comments
 (0)