|
| 1 | +#!/usr/bin/env python |
| 2 | +import argparse |
| 3 | +import itk |
| 4 | +import numpy as np |
| 5 | +from itk import RTK as rtk |
| 6 | + |
| 7 | + |
| 8 | +def build_parser(): |
| 9 | + parser = rtk.RTKArgumentParser( |
| 10 | + description=( |
| 11 | + "Computes expected photon counts from incident spectrum, material " |
| 12 | + "attenuations, detector response and material-decomposed projections" |
| 13 | + ) |
| 14 | + ) |
| 15 | + parser.add_argument( |
| 16 | + "--output", |
| 17 | + "-o", |
| 18 | + help="Output file name (photon counts)", |
| 19 | + type=str, |
| 20 | + required=True, |
| 21 | + ) |
| 22 | + parser.add_argument( |
| 23 | + "--input", "-i", help="Material-decomposed projections", type=str, required=True |
| 24 | + ) |
| 25 | + parser.add_argument( |
| 26 | + "--detector", "-d", help="Detector response file", type=str, required=True |
| 27 | + ) |
| 28 | + parser.add_argument( |
| 29 | + "--incident", help="Incident spectrum file", type=str, required=True |
| 30 | + ) |
| 31 | + parser.add_argument( |
| 32 | + "--attenuations", |
| 33 | + "-a", |
| 34 | + help="Material attenuations file", |
| 35 | + type=str, |
| 36 | + required=True, |
| 37 | + ) |
| 38 | + parser.add_argument( |
| 39 | + "--thresholds", |
| 40 | + "-t", |
| 41 | + help="Lower threshold of bins, expressed in pulse height", |
| 42 | + type=int, |
| 43 | + nargs="+", |
| 44 | + required=True, |
| 45 | + ) |
| 46 | + parser.add_argument( |
| 47 | + "--cramer_rao", help="Output Cramer-Rao lower bound file", type=str |
| 48 | + ) |
| 49 | + parser.add_argument( |
| 50 | + "--variances", help="Output variances of photon counts", type=str |
| 51 | + ) |
| 52 | + return parser |
| 53 | + |
| 54 | + |
| 55 | +def process(args_info: argparse.Namespace): |
| 56 | + PixelValueType = itk.F |
| 57 | + Dimension = 3 |
| 58 | + |
| 59 | + DecomposedProjectionType = itk.VectorImage[PixelValueType, Dimension] |
| 60 | + MeasuredProjectionsType = itk.VectorImage[PixelValueType, Dimension] |
| 61 | + IncidentSpectrumImageType = itk.Image[PixelValueType, Dimension] |
| 62 | + DetectorResponseImageType = itk.Image[PixelValueType, Dimension - 1] |
| 63 | + MaterialAttenuationsImageType = itk.Image[PixelValueType, Dimension - 1] |
| 64 | + |
| 65 | + # Read all inputs |
| 66 | + if args_info.verbose: |
| 67 | + print(f"Reading decomposed projections from {args_info.input}...") |
| 68 | + decomposedProjectionReader = itk.ImageFileReader[DecomposedProjectionType].New() |
| 69 | + decomposedProjectionReader.SetFileName(args_info.input) |
| 70 | + decomposedProjectionReader.Update() |
| 71 | + decomposedProjection = decomposedProjectionReader.GetOutput() |
| 72 | + |
| 73 | + if args_info.verbose: |
| 74 | + print(f"Reading incident spectrum from {args_info.incident}...") |
| 75 | + incidentSpectrumReader = itk.ImageFileReader[IncidentSpectrumImageType].New() |
| 76 | + incidentSpectrumReader.SetFileName(args_info.incident) |
| 77 | + incidentSpectrumReader.Update() |
| 78 | + incidentSpectrum = incidentSpectrumReader.GetOutput() |
| 79 | + |
| 80 | + if args_info.verbose: |
| 81 | + print(f"Reading detector response from {args_info.detector}...") |
| 82 | + detectorResponseReader = itk.ImageFileReader[DetectorResponseImageType].New() |
| 83 | + detectorResponseReader.SetFileName(args_info.detector) |
| 84 | + detectorResponseReader.Update() |
| 85 | + detectorResponse = detectorResponseReader.GetOutput() |
| 86 | + |
| 87 | + if args_info.verbose: |
| 88 | + print(f"Reading material attenuations from {args_info.attenuations}...") |
| 89 | + materialAttenuationsReader = itk.ImageFileReader[ |
| 90 | + MaterialAttenuationsImageType |
| 91 | + ].New() |
| 92 | + materialAttenuationsReader.SetFileName(args_info.attenuations) |
| 93 | + materialAttenuationsReader.Update() |
| 94 | + materialAttenuations = materialAttenuationsReader.GetOutput() |
| 95 | + |
| 96 | + # Get parameters from the images |
| 97 | + NumberOfMaterials = materialAttenuations.GetLargestPossibleRegion().GetSize()[0] |
| 98 | + NumberOfSpectralBins = len(args_info.thresholds) |
| 99 | + MaximumEnergy = incidentSpectrum.GetLargestPossibleRegion().GetSize()[0] |
| 100 | + |
| 101 | + # Generate a set of zero-filled photon count projections |
| 102 | + measuredProjections = MeasuredProjectionsType.New() |
| 103 | + measuredProjections.CopyInformation(decomposedProjection) |
| 104 | + measuredProjections.SetVectorLength(NumberOfSpectralBins) |
| 105 | + measuredProjections.Allocate() |
| 106 | + |
| 107 | + # Read the thresholds on command line |
| 108 | + thresholds = itk.VariableLengthVector[itk.D]() |
| 109 | + thresholds.SetSize(NumberOfSpectralBins + 1) |
| 110 | + for i in range(NumberOfSpectralBins): |
| 111 | + thresholds[i] = args_info.thresholds[i] |
| 112 | + |
| 113 | + # Add the maximum pulse height at the end |
| 114 | + MaximumPulseHeight = detectorResponse.GetLargestPossibleRegion().GetSize()[1] |
| 115 | + thresholds[NumberOfSpectralBins] = MaximumPulseHeight |
| 116 | + |
| 117 | + # Check that the inputs have the expected size |
| 118 | + idx = itk.Index[3]() |
| 119 | + idx.Fill(0) |
| 120 | + if decomposedProjection.GetPixel(idx).Size() != NumberOfMaterials: |
| 121 | + raise RuntimeError( |
| 122 | + f"Decomposed projections vector size {decomposedProjection.GetPixel(idx).Size()} != {NumberOfMaterials}" |
| 123 | + ) |
| 124 | + |
| 125 | + if measuredProjections.GetPixel(idx).Size() != NumberOfSpectralBins: |
| 126 | + raise RuntimeError( |
| 127 | + f"Spectral projections vector size {measuredProjections.GetPixel(idx).Size()} != {NumberOfSpectralBins}" |
| 128 | + ) |
| 129 | + |
| 130 | + if detectorResponse.GetLargestPossibleRegion().GetSize()[0] != MaximumEnergy: |
| 131 | + raise RuntimeError( |
| 132 | + f"Detector response energies {detectorResponse.GetLargestPossibleRegion().GetSize()[0]} != {MaximumEnergy}" |
| 133 | + ) |
| 134 | + |
| 135 | + # Create and set the filter |
| 136 | + forward = rtk.SpectralForwardModelImageFilter[ |
| 137 | + DecomposedProjectionType, |
| 138 | + MeasuredProjectionsType, |
| 139 | + IncidentSpectrumImageType, |
| 140 | + DetectorResponseImageType, |
| 141 | + MaterialAttenuationsImageType, |
| 142 | + ].New() |
| 143 | + forward.SetInputDecomposedProjections(decomposedProjection) |
| 144 | + forward.SetInputMeasuredProjections(measuredProjections) |
| 145 | + forward.SetInputIncidentSpectrum(incidentSpectrum) |
| 146 | + forward.SetDetectorResponse(detectorResponse) |
| 147 | + forward.SetMaterialAttenuations(materialAttenuations) |
| 148 | + forward.SetThresholds(thresholds) |
| 149 | + if args_info.cramer_rao: |
| 150 | + forward.SetComputeCramerRaoLowerBound(True) |
| 151 | + if args_info.variances: |
| 152 | + forward.SetComputeVariances(True) |
| 153 | + |
| 154 | + if args_info.verbose: |
| 155 | + print("Running spectral forward model...") |
| 156 | + forward.Update() |
| 157 | + |
| 158 | + # Inspect output and convert if necessary so we write a 3D VectorImage |
| 159 | + out = forward.GetOutput() |
| 160 | + if args_info.verbose: |
| 161 | + print("Forward output type:", type(out)) |
| 162 | + print("ImageDimension:", out.GetImageDimension()) |
| 163 | + print("ComponentsPerPixel:", out.GetNumberOfComponentsPerPixel()) |
| 164 | + print("Size:", tuple(out.GetLargestPossibleRegion().GetSize())) |
| 165 | + |
| 166 | + # If filter produced a 4D scalar image (dim=4, comps=1), convert it to a 3D VectorImage |
| 167 | + if out.GetImageDimension() == 4 and out.GetNumberOfComponentsPerPixel() == 1: |
| 168 | + if args_info.verbose: |
| 169 | + print( |
| 170 | + "Converting 4D scalar image to 3D VectorImage (components->vector)..." |
| 171 | + ) |
| 172 | + arr4 = itk.array_view_from_image(out) |
| 173 | + # Move the first axis (extra spatial dim) to the last axis to get (Z,Y,X,Components) |
| 174 | + arr_vec = np.moveaxis(arr4, 0, -1) |
| 175 | + arr_vec = arr_vec.astype(np.float32) |
| 176 | + vec_img = itk.image_from_array(arr_vec, is_vector=True) |
| 177 | + # copy spatial metadata from original (origin, spacing, direction may need adjustment) |
| 178 | + vec_img.SetOrigin(out.GetOrigin()[:3]) |
| 179 | + vec_img.SetSpacing(out.GetSpacing()[:3]) |
| 180 | + try: |
| 181 | + vec_img.SetDirection(out.GetDirection()[0:3, 0:3]) |
| 182 | + except Exception: |
| 183 | + # some ITK builds may not support slicing direction; ignore if unavailable |
| 184 | + pass |
| 185 | + writer_input = vec_img |
| 186 | + else: |
| 187 | + writer_input = out |
| 188 | + |
| 189 | + if args_info.verbose: |
| 190 | + print(f"Writing output photon counts to {args_info.output}...") |
| 191 | + WriterType = itk.ImageFileWriter[MeasuredProjectionsType] |
| 192 | + writer = WriterType.New() |
| 193 | + writer.SetFileName(args_info.output) |
| 194 | + writer.SetInput(writer_input) |
| 195 | + writer.SetImageIO(itk.MetaImageIO.New()) |
| 196 | + writer.Update() |
| 197 | + |
| 198 | + # If requested, write the Cramer-Rao lower bound |
| 199 | + if args_info.cramer_rao: |
| 200 | + if args_info.verbose: |
| 201 | + print(f"Writing Cramer-Rao lower bound to {args_info.cramer_rao}...") |
| 202 | + # Cramer-Rao output has same type as measured projections |
| 203 | + cramer_writer = WriterType.New() |
| 204 | + cramer_writer.SetFileName(args_info.cramer_rao) |
| 205 | + cramer_writer.SetInput(forward.GetOutputCramerRaoLowerBound()) |
| 206 | + cramer_writer.SetImageIO(itk.MetaImageIO.New()) |
| 207 | + cramer_writer.Update() |
| 208 | + |
| 209 | + # If requested, write the variance |
| 210 | + if args_info.variances: |
| 211 | + if args_info.verbose: |
| 212 | + print(f"Writing variances to {args_info.variances}...") |
| 213 | + var_writer = WriterType.New() |
| 214 | + var_writer.SetFileName(args_info.variances) |
| 215 | + var_writer.SetInput(forward.GetOutputVariances()) |
| 216 | + var_writer.SetImageIO(itk.MetaImageIO.New()) |
| 217 | + var_writer.Update() |
| 218 | + |
| 219 | + |
| 220 | +def main(argv=None): |
| 221 | + parser = build_parser() |
| 222 | + args_info = parser.parse_args(argv) |
| 223 | + process(args_info) |
| 224 | + |
| 225 | + |
| 226 | +if __name__ == "__main__": |
| 227 | + main() |
0 commit comments